Setup and Hardware Configuration¶

In [1]:
'''
For Running in local setup (CUDA 12.9 RTX 16GB GPU),
run in LINUX CLI and copy the URL to colab local runtime option

jupyter notebook --no-browser --ip=127.0.0.1 --port=8888 \
  --ServerApp.websocket_ping_interval=36000 \
  --ServerApp.websocket_ping_timeout=0

'''
# pip install nbconvert
# !ls
# !jupyter nbconvert colab18Sep2359b.ipynb --to html
# # if running in drive, uncomment
# import os
# os.getcwd()
# %cd '/content/drive/MyDrive/Colab Notebooks/Restormer_1209/Denoising'
Out[1]:
'\nFor Running in local setup (CUDA 12.9 RTX 16GB GPU),\nrun in LINUX CLI and copy the URL to colab local runtime option\n\njupyter notebook --no-browser --ip=127.0.0.1 --port=8888   --ServerApp.websocket_ping_interval=36000   --ServerApp.websocket_ping_timeout=0\n\n'
In [2]:
# Setup For CT physics
#
# Records the installed torch / CUDA / astra / odl versions next to the
# outputs so each run documents the environment it was produced in.

import numpy as np
import torch
print(torch.__version__, '# torch')
print(torch.version.cuda,'# cuda')
print(torch.cuda.get_arch_list() , '# cuda arch')
# !pip install astra-toolbox
import astra
print(astra.__version__, '# astra')
print(astra.get_gpu_info())  # GPUs visible to the astra toolbox
print('cuda available,', torch.cuda.is_available() ) # test pytorch is functioning with cuda
# !pip install odl
import odl
print(odl.__version__, '# odl')

# NOTE # PLACE IN A SEPARATE SECTION with desc
from dival.reconstructors.odl_reconstructors import FBPReconstructor
2.8.0+cu129 # torch
12.9 # cuda
['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120'] # cuda arch
2.3.1 # astra
GPU #0: NVIDIA GeForce RTX 5060 Ti, with 16310MB, CUDA compute capability 12.0
cuda available, True
0.8.2 # odl
2025-09-19 23:28:19.439486: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-19 23:28:19.629398: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1758320899.704631     402 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758320899.727147     402 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1758320899.901104     402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758320899.901130     402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758320899.901132     402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1758320899.901133     402 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2025-09-19 23:28:19.920752: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
In [3]:
# Dataset Manipulation utilities

# !pip install dival
from dival import get_standard_dataset
from dival.util.plot import plot_images
from dival.data import DataPairs

# For operator discretization
from dival.util.odl_utility import uniform_discr_element
In [4]:
# for CNN Architectures

import time
from datetime import datetime
from collections import OrderedDict
# !pip install hdf5storage
import hdf5storage
import os
import os.path as osp
import sys
import matplotlib.pyplot as plt
In [5]:
# for Gaussian Denoisers on CNN

# import pip
# !pip install opencv-python
# import cv2 # prerequisite for utils_image
import logging
In [6]:
# for Transformer architecture

import numpy as np
import os
import argparse
from tqdm import tqdm
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
from runpy import run_path
# !pip install opencv-python
import cv2

import os
# !pip install einops
import einops
import shutil

import torch.nn as nn
import torch
import torch.nn.functional as F
In [7]:
# For Transformer training

# !pip install pyyaml
import yaml       # training configs are stored as YAML
import argparse
import random
from pathlib import Path
import pprint
# !pip install lmdb
from pdb import set_trace as stx  # breakpoint shorthand used by BasicSR-style code
# (duplicate `import yaml` removed — it was imported twice in this cell)
In [8]:
# For Evaluations

from dival.evaluation import TaskTable
from dival.measure import PSNR
from dival.measure import SSIM
from dival.measure import L2
In [9]:
# another import with clashing names
# Importing utils and utils in order
#
# Two distinct packages both expose a top-level name ``utils``:
#   * /home/hiran/utils                      (KAIR-style denoiser helpers)
#   * /home/hiran/Restormer/Denoising/utils  (Restormer helpers)
# The working directory is switched around each import so the intended
# package wins the name lookup, and ``__file__`` is checked afterwards to
# fail fast if the wrong one was picked up.
# NOTE(review): hardcoded absolute paths tie this notebook to one machine.

%cd
if os.getcwd() != '/home/hiran':
  raise ImportError
from utils import utils_logger
from utils import utils_model
from utils import utils_image as utilsImg

from utils import utils_deblur
from utils import utils_pnp as pnp
from utils import utils_sisr_beforepytorchversion8 as sr # utils_sisr as sr # deprecated library

# verify the first `utils` resolved to the home-directory package
if utilsImg.__file__ != '/home/hiran/utils/utils_image.py':
  raise ImportError

import Restormer.Denoising.utils as utilsDn
if utilsDn.__file__ != '/home/hiran/Restormer/Denoising/utils.py':
  raise ImportError

%cd Restormer/Denoising
if os.getcwd() != '/home/hiran/Restormer/Denoising':
  raise ImportError

# make the Restormer architecture importable
sys.path.append('/home/hiran/Restormer')
from basicsr.models.archs.restormer_arch import Restormer
/home/hiran/miniconda3/envs/ctrecn3/lib/python3.10/site-packages/IPython/core/magics/osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.
  bkms = self.shell.db.get('bookmarks', {})
/home/hiran/miniconda3/envs/ctrecn3/lib/python3.10/site-packages/IPython/core/magics/osm.py:428: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
/home/hiran
/home/hiran/Restormer/Denoising
/home/hiran/miniconda3/envs/ctrecn3/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
In [10]:
# other miscellaneous imports

# from dival import Reconstructor
In [11]:
# import pip
# !pip install hdf5storage
# !pip uninstall torch torchvision torchaudio -y
# !pip install torch torchvision --index-url https://download.pytorch.org/whl/cu129
# !pip uninstall odl -y
# !pip install git+https://github.com/odlgroup/odl.git #code to installing night build 1 since AVOID_UNNECESSARY_COPY issue in example DeepImagePrior, but finally issue fixed via odlt.AVOID_UNNECESSARY_COPY = False - below
# pip install --force-reinstall git+https://github.com/odlgroup/odl.git@master
# pip install odl --no-cache-dir

Data Loading and Preprocessing¶

In [12]:
# ellipses data
# Synthetic ellipses phantoms (small, fast to iterate on).
dataset_ellipses = get_standard_dataset('ellipses', impl='astra_cuda')
test_data_ellipses = dataset_ellipses.get_data_pairs('test', 10)
# test_data_ellipses_all = dataset_ellipses.get_data_pairs('test')

# lodopab data (70K CT lq, gt images - 106GB)
IMPL = 'astra_cuda'
dataset = get_standard_dataset('lodopab', impl=IMPL) # on disk, not RAM
# Test slices of several sizes; only the ones actually used stay
# uncommented to keep RAM usage down.
test_data_2 = dataset.get_data_pairs('test', 2)
# test_data_50 = dataset.get_data_pairs('test', 50)
test_data_10 = dataset.get_data_pairs('test', 10)
# test_data_30 = dataset.get_data_pairs('test', 30)
test_data = dataset.get_data_pairs('test', 256)

# test_data_all = dataset.get_data_pairs('test', 1000)
# train_ds = dataset.get_data_pairs('train', 6000) # loads 6000 lq, gt pairs - 13GB RAM
train_2 = dataset.get_data_pairs('train', 2)
# validation_data = dataset.get_data_pairs('validation') # loads 3522 validation pairs - 13GB RAM

# del test_data_all # save RAM
# del train_ds # save RAM
# del validation_data # save RAM
# del test_data_ellipses # save RAM

CT Physics - Inverse Radon Transformation and Back Projections¶

Defining Discrete Reconstruction Spaces and Radon Transformations for the 3 Datasets¶

In [13]:
# ray_trafo = dataset.get_ray_trafo(impl=IMPL)
# Forward operators (Radon transforms) for the two dival datasets.
ray_trafo_lodopab = dataset.get_ray_trafo(impl=IMPL)
ray_trafo_ellipses = dataset_ellipses.get_ray_trafo(impl=IMPL)

# Sinogram projection spaces

reco_space_lodopab = ray_trafo_lodopab.domain # .zero()
reco_space_ellipses = ray_trafo_ellipses.domain # .zero()
reco_space_shepp = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300],
    dtype='float32') # .zero() # x,y dimension min & max # grid lines

geometry = odl.tomo.cone_beam_geometry(reco_space_shepp, 40, 40, 360) # build our CT machine geometry using odl # object = human cross-section space, source = ray emitter radius from origin, likewise radius from origin to detector, optional no. of angles in our geometry

# radon transform function ( build sinogram from a ct scan )
ray_trafo_shepp = odl.tomo.RayTransform(reco_space_shepp, geometry, impl=IMPL)

# setting up shepp_logan phantom as a testing platform
phantom = odl.phantom.shepp_logan(reco_space_shepp, modified=True) # import the standard scientific sample CT named Shepp-Logan; that's our gt
ground_truth = phantom
proj_data = ray_trafo_shepp(phantom) # call the forward operator to build the sinogram
# Poisson draws (lam=0.3) are ADDED to the sinogram as noise.
# NOTE(review): no seed is set in this cell, so the stored observation
# depends on global RNG state when this cell runs — seeds set in later
# cells cannot change noise that is already baked in here; confirm intended.
observation = (proj_data + np.random.poisson(0.3, proj_data.shape)).asarray()
test_data_shepp = DataPairs(observation, ground_truth, name='shepp-logan + pois')

Filtered Back Projection¶

In [14]:
# back projection model original version https://odlgroup.github.io/odl/
# Baseline filtered-back-projection reconstructors, one per forward operator.

from dival.reference_reconstructors import (
    check_for_params, download_params, get_params_path)

from dival.reconstructors import Reconstructor, IterativeReconstructor
from dival.reconstructors.odl_reconstructors import FBPReconstructor

reconstructor_lodopab = FBPReconstructor(dataset.get_ray_trafo(impl=IMPL))
reconstructor_ellipses = FBPReconstructor(dataset_ellipses.ray_trafo)
reconstructor_shepp = FBPReconstructor(ray_trafo_shepp)
In [ ]:
#  FBP model: using odl lib, mapping projection into a given reconstruction space

class FBPReconstructor_demo(Reconstructor):
    """Reconstructor applying filtered back-projection.

    Demo copy (for study) of dival's FBPReconstructor, showing how the
    FBP operator is assembled from the ray transform and hyper-parameters.

    Attributes
    ----------
    fbp_op : `odl.operator.Operator`
        The operator applying filtered back-projection.
        It is computed in the constructor, and is recomputed for each
        reconstruction if ``recompute_fbp_op == True`` (since parameters could
        change).
    """
    # NOTE: the docstring above was originally placed *after* HYPER_PARAMS,
    # where it is just a discarded string expression, not the class docstring.

    HYPER_PARAMS = {
        'filter_type':
            {'default': 'Ram-Lak',
             'choices': ['Ram-Lak', 'Shepp-Logan', 'Cosine', 'Hamming',
                         'Hann']},
        'frequency_scaling':
            {'default': 1.,
             'range': [0, 1],
             'grid_search_options': {'num_samples': 11}}
    }

    def __init__(self, ray_trafo, padding=True, hyper_params=None,
                 pre_processor=None, post_processor=None,
                 recompute_fbp_op=True, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator. See `odl.tomo.fbp_op` for details.
        padding : bool, optional
            Whether to use padding (the default is ``True``).
            See `odl.tomo.fbp_op` for details.
        pre_processor : callable, optional
            Callable that takes the observation and returns the sinogram that
            is passed to the filtered back-projection operator.
        post_processor : callable, optional
            Callable that takes the filtered back-projection and returns the
            final reconstruction.
        recompute_fbp_op : bool, optional
            Whether :attr:`fbp_op` should be recomputed on each call to
            :meth:`reconstruct`. Must be ``True`` (default) if changes to
            :attr:`ray_trafo`, :attr:`hyper_params` or :attr:`padding` are
            planned in order to use the updated values in :meth:`reconstruct`.
            If none of these attributes will change, you may specify
            ``recompute_fbp_op==False``, so :attr:`fbp_op` can be computed
            only once, improving reconstruction time efficiency.
        """
        self.ray_trafo = ray_trafo
        self.padding = padding
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        super().__init__(
            reco_space=ray_trafo.domain, observation_space=ray_trafo.range,
            hyper_params=hyper_params, **kwargs)
        # FIX: the original cell referenced a bare ``fbp_op`` that is never
        # imported in this notebook (dival's source does
        # ``from odl.tomo import fbp_op``); qualify it so instantiation does
        # not raise NameError.
        self.fbp_op = odl.tomo.fbp_op(self.ray_trafo, padding=self.padding,
                                      **self.hyper_params)
        self.recompute_fbp_op = recompute_fbp_op

    def _reconstruct(self, observation, out):
        # Optional user hook on the raw observation.
        if self.pre_processor is not None:
            observation = self.pre_processor(observation)
        # Rebuild the operator so hyper-parameter changes take effect.
        if self.recompute_fbp_op:
            self.fbp_op = odl.tomo.fbp_op(self.ray_trafo, padding=self.padding,
                                          **self.hyper_params)
        if out in self.reco_space:
            self.fbp_op(observation, out=out)
        else:  # out is e.g. numpy array, cannot be passed to fbp_op
            out[:] = self.fbp_op(observation)
        if self.post_processor is not None:
            out[:] = self.post_processor(out)

Hyper Parameter tuning in Filtered Back Projections¶

In [ ]:
np.random.seed(0)

# Grid search over FBP hyper-parameters (filter type x frequency scaling)
# on the ellipses test pairs; swap the commented assignments to run the
# same search on lodopab instead.
reconstructor_ = reconstructor_ellipses
# reconstructor_ = reconstructor_lodopab

test_data_ = test_data_ellipses
# test_data_ = test_data_10

# %% task table and reconstructors
eval_tt = TaskTable()

eval_tt.append(reconstructor=reconstructor_, measures=[PSNR, SSIM],
               test_data=test_data_,
               hyper_param_choices={'filter_type': ['Ram-Lak', 'Hann'],
                                    'frequency_scaling': [0.8, 0.9, 1.]})

# %% run task table
results = eval_tt.run()
print(results.to_string(show_columns=['misc']))  # best param : Hann 0.8

# %% plot reconstructions
fig = results.plot_all_reconstructions(test_ind=range(1),
                                       fig_size=(9, 4), vrange='individual')
running task 0/1 ...
sub-task 0/6 ...
sub-task 1/6 ...
sub-task 2/6 ...
sub-task 3/6 ...
sub-task 4/6 ...
sub-task 5/6 ...
ResultTable(results=
                          reconstructor       test_data                     measure_values                                                                 misc
task_ind sub_task_ind                                                                                                                                          
0        0             FBPReconstructor  test part 0:10  mean: {psnr: 22.16, ssim: 0.4689}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 0.8}}
         1             FBPReconstructor  test part 0:10  mean: {psnr: 21.78, ssim: 0.4536}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 0.9}}
         2             FBPReconstructor  test part 0:10  mean: {psnr: 21.36, ssim: 0.4381}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 1.0}}
         3             FBPReconstructor  test part 0:10  mean: {psnr: 24.57, ssim: 0.5963}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 0.8}}
         4             FBPReconstructor  test part 0:10  mean: {psnr: 24.47, ssim: 0.5831}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 0.9}}
         5             FBPReconstructor  test part 0:10  mean: {psnr: 24.33, ssim: 0.5705}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 1.0}}
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# %% task table and reconstructors
# Same FBP hyper-parameter grid search, run on the lodopab reconstructor
# and its 10 test pairs.
eval_tt = TaskTable()

eval_tt.append(reconstructor=reconstructor_lodopab, measures=[PSNR, SSIM],
               test_data=test_data_10,
               hyper_param_choices={'filter_type': ['Ram-Lak', 'Hann'],
                                    'frequency_scaling': [0.8, 0.9, 1.]})

# %% run task table
results = eval_tt.run()
print(results.to_string(show_columns=['misc'])) # best param : Hann 0.8

# %% plot reconstructions
fig = results.plot_all_reconstructions(test_ind=range(1),
                                       fig_size=(9, 4), vrange='individual')
running task 0/1 ...
sub-task 0/6 ...
sub-task 1/6 ...
sub-task 2/6 ...
sub-task 3/6 ...
sub-task 4/6 ...
sub-task 5/6 ...
ResultTable(results=
                          reconstructor       test_data                     measure_values                                                                 misc
task_ind sub_task_ind                                                                                                                                          
0        0             FBPReconstructor  test part 0:10  mean: {psnr: 27.08, ssim: 0.5211}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 0.8}}
         1             FBPReconstructor  test part 0:10   mean: {psnr: 26.23, ssim: 0.487}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 0.9}}
         2             FBPReconstructor  test part 0:10   mean: {psnr: 25.4, ssim: 0.4552}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 1.0}}
         3             FBPReconstructor  test part 0:10  mean: {psnr: 31.22, ssim: 0.7181}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 0.8}}
         4             FBPReconstructor  test part 0:10  mean: {psnr: 31.05, ssim: 0.7024}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 0.9}}
         5             FBPReconstructor  test part 0:10  mean: {psnr: 30.78, ssim: 0.6856}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 1.0}}
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# %% task table and reconstructors
# Same FBP hyper-parameter grid search, run on the noisy shepp-logan pair.
eval_tt = TaskTable()

eval_tt.append(reconstructor=reconstructor_shepp, measures=[PSNR, SSIM],
               test_data=test_data_shepp,
               hyper_param_choices={'filter_type': ['Ram-Lak', 'Hann'],
                                    'frequency_scaling': [0.8, 0.9, 1.]})

# %% run task table
results = eval_tt.run()
print(results.to_string(show_columns=['misc']))

# %% plot reconstructions
fig = results.plot_all_reconstructions(test_ind=range(1),
                                       fig_size=(9, 4), vrange='individual')
running task 0/1 ...
sub-task 0/6 ...
sub-task 1/6 ...
sub-task 2/6 ...
sub-task 3/6 ...
sub-task 4/6 ...
sub-task 5/6 ...
ResultTable(results=
                          reconstructor           test_data                      measure_values                                                                 misc
task_ind sub_task_ind                                                                                                                                               
0        0             FBPReconstructor  shepp-logan + pois   mean: {psnr: 13.87, ssim: 0.1107}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 0.8}}
         1             FBPReconstructor  shepp-logan + pois   mean: {psnr: 12.84, ssim: 0.1026}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 0.9}}
         2             FBPReconstructor  shepp-logan + pois  mean: {psnr: 11.89, ssim: 0.09516}  {'hp_choice': {'filter_type': 'Ram-Lak', 'frequency_scaling': 1.0}}
         3             FBPReconstructor  shepp-logan + pois   mean: {psnr: 21.39, ssim: 0.2231}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 0.8}}
         4             FBPReconstructor  shepp-logan + pois   mean: {psnr: 20.51, ssim: 0.1985}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 0.9}}
         5             FBPReconstructor  shepp-logan + pois   mean: {psnr: 19.66, ssim: 0.1801}     {'hp_choice': {'filter_type': 'Hann', 'frequency_scaling': 1.0}}
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

EDA - Exploratory Data Analysis¶

In [ ]:
def plot_ctrecn(test_data, recos2, psnrs_=None, visuals=2):
  """Plot reconstructions side by side with their ground truths.

  Parameters
  ----------
  test_data : object with a ``ground_truth`` sequence (e.g. dival DataPairs).
  recos2 : sequence of reconstructions, index-aligned with the ground truths.
  psnrs_ : sequence of per-sample PSNR values, optional. When given, the
      mean is printed and each figure is labeled with its sample's PSNR.
  visuals : int, number of samples to plot (default 2).
  """
  # FIX: the original used ``psnrs_ != None`` — for a numpy array that is an
  # elementwise comparison whose truth value is ambiguous (raises ValueError);
  # identity comparison with None is the correct test.
  if psnrs_ is not None:
    print('whole eval mean psnr: {:f}'.format(np.mean(psnrs_)))
  for i in range(visuals):
      _, ax = plot_images([recos2[i], test_data.ground_truth[i]],
                          fig_size=(10, 4))
      if psnrs_ is not None:
          ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs_[i]))
      ax[0].set_title('Reconstruction')
      ax[1].set_title('ground truth')
      ax[0].figure.suptitle('test sample {:d}'.format(i))
In [ ]:
# Reconstruct the 10 lodopab test observations with plain FBP.
recos2 = []
# del reconstructor
# borrowing the reconstructor model for the moment; to customize:
# FBPReconstructor(ray_trafo_lodopab,
#                  hyper_params={'filter_type': 'Ram-Lak', 'frequency_scaling': 1.0})
reconstructor = reconstructor_lodopab

with torch.no_grad(): # save memory by not calculating gradient
  for obs, gt in test_data_10:
      torch.cuda.ipc_collect() # collects unnecessary inter-process comm.s and free VRAM
      torch.cuda.empty_cache() # clear cache while iterating
      # FIX: call through the local alias (the original bypassed it and
      # called reconstructor_lodopab directly, defeating the point of the
      # alias); same object today, so behavior is unchanged.
      reco = reconstructor.reconstruct(obs) # returns odl elem H,W with normalized pixel vals
      recos2.append(reco)
In [ ]:
plot_ctrecn(test_data_10, recos2,  visuals = 10)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Experiment - simple Network¶

In [ ]:
# Sanity check: build a small random tensor and print it to confirm the
# torch runtime is functional before heavier cells run.
x = torch.rand(5, 3)
status_note = '\n# test pytorch is functioning with cuda'
print(x, status_note)
tensor([[0.9525, 0.6503, 0.9497],
        [0.2197, 0.3869, 0.6673],
        [0.5443, 0.8283, 0.7626],
        [0.9766, 0.0579, 0.6842],
        [0.3690, 0.5183, 0.7585]]) 
# test pytorch is functioning with cuda
In [ ]:
# from dival.util.odl_utility import uniform_discr_element
# Wrap a plain sequence into an element of an odl uniform discretization;
# printing shows the (abbreviated) array values.
print( uniform_discr_element([0, 1, 2, 3, 4, 5, 6] ) )
[ 0.,  1.,  2., ...,  4.,  5.,  6.]
In [ ]:
# Bare last expression: the rich repr shows both the discretized space
# and the element values (see the Out[] below).
uniform_discr_element([0, 1.5, 2.5, 3, 4, 5, 6] )
Out[ ]:
uniform_discr(-3.5, 3.5, 7).element(
    [ 0. ,  1.5,  2.5, ...,  4. ,  5. ,  6. ]
)
In [ ]:
# Discretize 15 random samples.
# NOTE(review): no seed is set in this cell, so the printed values depend
# on the RNG state left by earlier cells — confirm that is acceptable here.
L = np.random.rand(15)
L_discrete = uniform_discr_element(L)
print( L_discrete)
[ 0.5488135 ,  0.71518937,  0.60276338, ...,  0.56804456,  0.92559664,
  0.07103606]
In [ ]:
# Confirm the wrapper type returned by uniform_discr_element.
type( uniform_discr_element(L) )
Out[ ]:
odl.discr.discr_space.DiscretizedSpaceElement
In [ ]:
np.random.seed(1)  # make the added normal noise reproducible

# Toy inverse problem: observation = ground_truth + 1 + N(0, 1).
# NOTE(review): this rebinds ``ground_truth``, ``observation`` and
# ``test_data``, which earlier cells defined for the shepp/lodopab CT data —
# any cell that needs the CT versions must run before this one
# (hidden-state hazard under Restart & Run All).
ground_truth = uniform_discr_element([0, 1, 2, 3, 4, 5, 6])
observation = ground_truth + 1
observation += np.random.normal(size=observation.shape)
test_data = DataPairs(observation, ground_truth, name='x + 1 + normal')
eval_tt = TaskTable()
In [ ]:
class MinusOneReconstructor(Reconstructor):
    def reconstruct(self, observation):
        return observation - 1


reconstructor = MinusOneReconstructor(name='y-1')
eval_tt.append(reconstructor=reconstructor, test_data=test_data,
               measures=[L2])
results = eval_tt.run()
results.plot_reconstruction(0)
print(results)
running task 0/1 ...
ResultTable(results=
                      reconstructor       test_data     measure_values
task_ind sub_task_ind                                                 
0        0                      y-1  x + 1 + normal  mean: {l2: 3.679}
)
No description has been provided for this image

Experiment - Inferencing using published Networks and shepp-logan phantom CT scan¶

In [29]:
from dival.reconstructors.odl_reconstructors import (FBPReconstructor,
                                                     CGReconstructor,
                                                     GaussNewtonReconstructor,
                                                     LandweberReconstructor,
                                                     MLEMReconstructor,
                                                     ISTAReconstructor,
                                                     PDHGReconstructor,
                                                     DouglasRachfordReconstructor,
                                                     ForwardBackwardReconstructor,
                                                     ADMMReconstructor,
                                                     BFGSReconstructor)
In [ ]:
# print( np.random.seed(0) )
# odl.tomo.cone_beam_geometry?
# DataPairs?
# TaskTable?
In [ ]:
np.random.seed(0) # to make pois noise in obs is consistent throughout reconstructors; fair evaluation hence.
# NOTE(review): the Poisson noise in test_data_shepp was drawn once when the
# data cell ran; this seed can only fix randomness used *during*
# reconstruction, if any — confirm the comment above still holds.


# %% task table and reconstructors
# Compare classical odl reconstructors on the noisy shepp-logan pair.
eval_tt = TaskTable()
fbp_reconstructor = FBPReconstructor(ray_trafo_shepp)
cg_reconstructor = CGReconstructor(ray_trafo_shepp, reco_space_shepp.zero(), 4)
gn_reconstructor = GaussNewtonReconstructor(ray_trafo_shepp, reco_space_shepp.zero(), 2)
lw_reconstructor = LandweberReconstructor(ray_trafo_shepp, reco_space_shepp.zero(), 8)
mlem_reconstructor = MLEMReconstructor(ray_trafo_shepp, 0.5*reco_space_shepp.one(), 1)
ista_reconstructor = ISTAReconstructor(ray_trafo_shepp,reco_space_shepp.zero(), 10) # works
pdhg_reconstructor = PDHGReconstructor(ray_trafo_shepp, reco_space_shepp.zero(), 10) # operand issue
dougrach_reconstructor = DouglasRachfordReconstructor(ray_trafo_shepp,
                                                      reco_space_shepp.zero(), 10) # operand issue
forwardbackward_reconstructor = ForwardBackwardReconstructor(ray_trafo_shepp,
                                                      reco_space_shepp.zero(), 10) # operand issue
admm_reconstructor = ADMMReconstructor(ray_trafo_shepp, reco_space_shepp.zero(), 10) # works
bfgs_reconstructor = BFGSReconstructor(ray_trafo_shepp, reco_space_shepp.zero(), 10) # works

reconstructors = [fbp_reconstructor, cg_reconstructor, gn_reconstructor,
                  lw_reconstructor, mlem_reconstructor , ista_reconstructor , admm_reconstructor, bfgs_reconstructor] #,   pdhg_reconstructor, dougrach_reconstructor ,forwardbackward_reconstructor ]
                  # removed pdhg/dougrach/forwardbackward due to unsupported operand type(s) for +: 'MultiplyOperator' and 'DiscretizedSpaceElement'

options = {'save_iterates': True}

eval_tt.append_all_combinations(reconstructors=reconstructors,
                                test_data=[test_data_shepp], options=options) # original

# testing one reconstructor
# eval_tt.append_all_combinations(reconstructors=[dougrach_reconstructor],
#                                 test_data=[test_data_shepp], options=options)

# %% run task table
results = eval_tt.run()
results.apply_measures([PSNR, SSIM])
print(results)

# %% plot reconstructions
fig = results.plot_all_reconstructions(fig_size=(9, 4), vrange='individual')

# %% plot convergence of CG # comment out if testing one reconstructor
results.plot_convergence(1, fig_size=(9, 6), gridspec_kw={'hspace': 0.5})

# %% plot performance
results.plot_performance(PSNR, figsize=(10, 4))
running task 0/8 ...
running task 1/8 ...
running task 2/8 ...
running task 3/8 ...
running task 4/8 ...
running task 5/8 ...
running task 6/8 ...
running task 7/8 ...
ResultTable(results=
                                  reconstructor           test_data                      measure_values
task_ind sub_task_ind                                                                                  
0        0                     FBPReconstructor  shepp-logan + pois  mean: {psnr: 11.89, ssim: 0.09516}
1        0                      CGReconstructor  shepp-logan + pois   mean: {psnr: 17.73, ssim: 0.5474}
2        0             GaussNewtonReconstructor  shepp-logan + pois    mean: {psnr: 21.2, ssim: 0.4112}
3        0               LandweberReconstructor  shepp-logan + pois   mean: {psnr: 16.52, ssim: 0.3932}
4        0                    MLEMReconstructor  shepp-logan + pois   mean: {psnr: 14.26, ssim: 0.3239}
5        0                    ISTAReconstructor  shepp-logan + pois   mean: {psnr: 18.84, ssim: 0.5827}
6        0                    ADMMReconstructor  shepp-logan + pois   mean: {psnr: 20.73, ssim: 0.4751}
7        0                    BFGSReconstructor  shepp-logan + pois    mean: {psnr: 22.3, ssim: 0.3976}
)
Out[ ]:
<Axes: title={'center': 'peak signal-to-noise ratio'}>
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [15]:
def inference_by_dataset_model (dataset_ = test_data_shepp, ray_trafo_=ray_trafo_shepp, reco_space_=reco_space_shepp):
  """Run the suite of classical odl reconstructors on one dataset.

  Parameters
  ----------
  dataset_ : dival ``DataPairs`` of (observation, ground_truth) pairs.
      NOTE(review): despite the name this is *test data*, not a dataset
      object — consider renaming in a future pass.
  ray_trafo_ : forward operator matching ``dataset_``.
  reco_space_ : reconstruction (domain) space of ``ray_trafo_``.

  Defaults are bound to the shepp-logan notebook globals at definition
  time, so redefining those globals later does not change the defaults.
  """
  np.random.seed(0) # to make pois noise in obs is consistent throughout reconstructors; fair evaluation hence.
  # %% task table and reconstructors
  eval_tt = TaskTable()
  fbp_reconstructor = FBPReconstructor(ray_trafo_)
  cg_reconstructor = CGReconstructor(ray_trafo_, reco_space_.zero(), 4)
  gn_reconstructor = GaussNewtonReconstructor(ray_trafo_, reco_space_.zero(), 2)
  lw_reconstructor = LandweberReconstructor(ray_trafo_, reco_space_.zero(), 8)
  mlem_reconstructor = MLEMReconstructor(ray_trafo_, 0.5*reco_space_.one(), 1)
  ista_reconstructor = ISTAReconstructor(ray_trafo_,reco_space_.zero(), 10) # works
  pdhg_reconstructor = PDHGReconstructor(ray_trafo_, reco_space_.zero(), 10) # operand issue
  dougrach_reconstructor = DouglasRachfordReconstructor(ray_trafo_,
                                                        reco_space_.zero(), 10) # operand issue
  forwardbackward_reconstructor = ForwardBackwardReconstructor(ray_trafo_,
                                                        reco_space_.zero(), 10) # operand issue
  admm_reconstructor = ADMMReconstructor(ray_trafo_, reco_space_.zero(), 10) # works
  bfgs_reconstructor = BFGSReconstructor(ray_trafo_, reco_space_.zero(), 10) # works

  reconstructors = [fbp_reconstructor, cg_reconstructor, gn_reconstructor,
                    lw_reconstructor, mlem_reconstructor , ista_reconstructor , admm_reconstructor, bfgs_reconstructor] #,   pdhg_reconstructor, dougrach_reconstructor ,forwardbackward_reconstructor ]
                    # removed pdhg/dougrach/forwardbackward due to unsupported operand type(s) for +: 'MultiplyOperator' and 'DiscretizedSpaceElement'

  # NOTE(review): another cell passes 'skip_training' instead of 'training' —
  # confirm which key TaskTable actually reads; unknown keys may be ignored.
  options = {'save_iterates': True, 'training': False}

  eval_tt.append_all_combinations(reconstructors=reconstructors,
                                  test_data= [dataset_], options=options) # original


  # %% run task table
  results = eval_tt.run()
  results.apply_measures([PSNR, SSIM])
  print(results)

  # # %% plot reconstructions
  # fig = results.plot_all_reconstructions(fig_size=(9, 4), vrange='individual')

  # # %% plot convergence of CG # comment out if testing one reconstructor
  # results.plot_convergence(1, fig_size=(9, 6), gridspec_kw={'hspace': 0.5})

  # # %% plot performance
  # results.plot_performance(PSNR, figsize=(10, 4))
In [ ]:
# inference_by_dataset_model()
running task 0/8 ...
running task 1/8 ...
running task 2/8 ...
running task 3/8 ...
running task 4/8 ...
running task 5/8 ...
running task 6/8 ...
running task 7/8 ...
ResultTable(results=
                                  reconstructor           test_data                      measure_values
task_ind sub_task_ind                                                                                  
0        0                     FBPReconstructor  shepp-logan + pois  mean: {psnr: 11.89, ssim: 0.09516}
1        0                      CGReconstructor  shepp-logan + pois   mean: {psnr: 17.73, ssim: 0.5474}
2        0             GaussNewtonReconstructor  shepp-logan + pois    mean: {psnr: 21.2, ssim: 0.4112}
3        0               LandweberReconstructor  shepp-logan + pois   mean: {psnr: 16.52, ssim: 0.3932}
4        0                    MLEMReconstructor  shepp-logan + pois   mean: {psnr: 14.26, ssim: 0.3239}
5        0                    ISTAReconstructor  shepp-logan + pois   mean: {psnr: 18.84, ssim: 0.5827}
6        0                    ADMMReconstructor  shepp-logan + pois   mean: {psnr: 20.73, ssim: 0.4751}
7        0                    BFGSReconstructor  shepp-logan + pois    mean: {psnr: 22.3, ssim: 0.3976}
)
In [16]:
def inference_by_mult_datasets_models(inf_published=True, oth_recons=None,
                                      dataset_list=[dataset_ellipses, dataset],
                                      test_data_list=[test_data_ellipses, test_data_10]):
  """Run the classical-reconstructor task table on several dataset/test-data pairs.

  Parameters
  ----------
  inf_published : bool, optional
      If True, evaluate the built-in list of classical reconstructors.
  oth_recons : list, optional
      Extra reconstructor instances to evaluate in addition (default: none).
      (Was a mutable ``[]`` default; ``None`` avoids the shared-default pitfall.)
  dataset_list : list, optional
      Dival datasets; each provides the ray transform / reconstruction space.
  test_data_list : list, optional
      Test data (DataPairs), matching ``dataset_list`` element-wise.
  """
  if oth_recons is None:
    oth_recons = []

  def _run_on(test_data_, ray_trafo_, reco_space_):
    # Fixed seed so the Poisson noise in the observations is identical for
    # every reconstructor -> fair evaluation.
    np.random.seed(0)

    # %% task table
    eval_tt = TaskTable()

    fbp_reconstructor = FBPReconstructor(ray_trafo_)
    cg_reconstructor = CGReconstructor(ray_trafo_, reco_space_.zero(), 4)
    gn_reconstructor = GaussNewtonReconstructor(ray_trafo_, reco_space_.zero(), 2)
    lw_reconstructor = LandweberReconstructor(ray_trafo_, reco_space_.zero(), 8)
    mlem_reconstructor = MLEMReconstructor(ray_trafo_, 0.5*reco_space_.one(), 1)
    ista_reconstructor = ISTAReconstructor(ray_trafo_, reco_space_.zero(), 10)
    admm_reconstructor = ADMMReconstructor(ray_trafo_, reco_space_.zero(), 10)
    bfgs_reconstructor = BFGSReconstructor(ray_trafo_, reco_space_.zero(), 10)
    # PDHGReconstructor, DouglasRachfordReconstructor and
    # ForwardBackwardReconstructor were previously constructed but never
    # evaluated (removed due to "unsupported operand type(s) for +:
    # 'MultiplyOperator' and 'DiscretizedSpaceElement'"); they are no longer
    # built and discarded.

    options = {'save_iterates': True, 'skip_training': True}

    # NOTE: ADMM/BFGS (and CG/LW/MLEM on the full lodopab set) have previously
    # run out of memory when evaluated over a whole dataset.
    recons_published = [fbp_reconstructor, gn_reconstructor, ista_reconstructor,
                        cg_reconstructor, lw_reconstructor, mlem_reconstructor,
                        admm_reconstructor, bfgs_reconstructor]

    recons = list(recons_published) if inf_published else []
    recons = recons + oth_recons

    eval_tt.append_all_combinations(reconstructors=recons,
                                    test_data=[test_data_], options=options)

    # %% run task table
    results = eval_tt.run()
    results.apply_measures([PSNR, SSIM])
    print(results)

    # # %% plot reconstructions
    # fig = results.plot_all_reconstructions(fig_size=(9, 4), vrange='individual')

    # # %% plot convergence of CG # comment out if testing one reconstructor
    # results.plot_convergence(1, fig_size=(9, 6), gridspec_kw={'hspace': 0.5})

    # # %% plot performance
    # results.plot_performance(PSNR, figsize=(10, 4))

  for dataset_i, test_data_i in zip(dataset_list, test_data_list):
    # Compute the ray transform once per dataset (was computed twice before:
    # once for the operator and once more for its domain).
    ray_trafo_i = dataset_i.get_ray_trafo(impl=IMPL)
    _run_on(test_data_i, ray_trafo_i, ray_trafo_i.domain)
In [ ]:
# inference_by_mult_datasets_models( dataset_list = [dataset_ellipses, dataset], test_data_list = [test_data_ellipses_all, test_data_all]) # Beware of OOM Out of Memory
<dival.datasets.dataset.ObservationGroundTruthPairDataset object at 0x7d5ab0a40c40>
<class 'dival.data.DataPairs'>
<class 'odl.tomo.operators.ray_trafo.RayTransform'>
False
running task 0/6 ...
running task 1/6 ...
running task 2/6 ...
running task 3/6 ...
running task 4/6 ...
running task 5/6 ...
ResultTable(results=
                                  reconstructor  test_data                         measure_values
task_ind sub_task_ind                                                                            
0        0                     FBPReconstructor  test part      mean: {psnr: 21.06, ssim: 0.4381}
1        0                      CGReconstructor  test part      mean: {psnr: 23.25, ssim: 0.5849}
2        0             GaussNewtonReconstructor  test part      mean: {psnr: 24.53, ssim: 0.6224}
3        0               LandweberReconstructor  test part      mean: {psnr: 20.13, ssim: 0.4495}
4        0                    MLEMReconstructor  test part      mean: {psnr: 15.23, ssim: 0.2639}
5        0                    ISTAReconstructor  test part  mean: {psnr: -183.5, ssim: 1.798e-19}
)
<dival.datasets.lodopab_dataset.LoDoPaBDataset object at 0x7d5abd1fa8f0>
<class 'dival.data.DataPairs'>
<class 'odl.tomo.operators.ray_trafo.RayTransform'>
False
running task 0/6 ...
running task 1/6 ...
In [ ]:
# inference_by_mult_datasets_models( dataset_list = [dataset], test_data_list = [test_data_all])
running task 0/1 ...
ResultTable(results=
                           reconstructor  test_data                      measure_values
task_ind sub_task_ind                                                                  
0        0             ISTAReconstructor  test part  mean: {psnr: 10.25, ssim: 0.07306}
)
In [ ]:
# Show the element type, then build five zero elements by cancellation
# (x + (-1)*x) and display the resulting list.
print(type(L_discrete))
L2 = []
for _ in range(5):
    L2.append(L_discrete + (-1) * L_discrete)
print(L2)
<class 'odl.discr.discr_space.DiscretizedSpaceElement'>
[uniform_discr(-7.5, 7.5, 15).element(
    [ 0.,  0.,  0., ...,  0.,  0.,  0.]
), uniform_discr(-7.5, 7.5, 15).element(
    [ 0.,  0.,  0., ...,  0.,  0.,  0.]
), uniform_discr(-7.5, 7.5, 15).element(
    [ 0.,  0.,  0., ...,  0.,  0.,  0.]
), uniform_discr(-7.5, 7.5, 15).element(
    [ 0.,  0.,  0., ...,  0.,  0.,  0.]
), uniform_discr(-7.5, 7.5, 15).element(
    [ 0.,  0.,  0., ...,  0.,  0.,  0.]
)]
In [ ]:
# Wrap the list of zero elements into a single discretized element.
# NOTE(review): assumes `uniform_discr_element` accepts a list of
# DiscretizedSpaceElement objects — confirm against the odl docs.
L2 = uniform_discr_element(L2)
# Last expression of the cell: the negated element is shown as output.
(-1) * L_discrete
Out[ ]:
uniform_discr(-7.5, 7.5, 15).element(
    [-0.5488135 , -0.71518937, -0.60276338, ..., -0.56804456, -0.92559664,
     -0.07103606]
)

Experiment - Training CT reconstruction Networks - an Inverse Radon Map model¶

In [17]:
"""
Train IRadonMapReconstructor on 'lodopab'.
"""
from dival.reconstructors.iradonmap_reconstructor import IRadonMapReconstructor
from dival.reference_reconstructors import (
    check_for_params, download_params, get_hyper_params_path)
In [18]:
# Previous runs reused generic names; clear them if they linger in the kernel.
# del LOG_DIR
# del SAVE_BEST_LEARNED_PARAMS_PATH

# Tensorboard log dir and best-weights path for the IRadonMap experiment
# (relative to the notebook's working directory).
LOG_DIR_iradonmap = '../../logs_17Sep/lodopab_iradonmap'
SAVE_BEST_LEARNED_PARAMS_PATH_iradonmap = '../../params_17Sep/lodopab_iradonmap'
In [19]:
# Build the IRadonMap reconstructor on the lodopab ray transform; best weights
# found during training are saved under the path configured above.
iradonmap_reconstructor = IRadonMapReconstructor(
    ray_trafo_lodopab, log_dir=LOG_DIR_iradonmap,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH_iradonmap)

#%% TRAIN - obtain reference hyper parameters
if not check_for_params('iradonmap', 'lodopab', include_learned=True): # include_learned=True also downloads the trained weights
    download_params('iradonmap', 'lodopab', include_learned=True)
hyper_params_path = get_hyper_params_path('iradonmap', 'lodopab')
iradonmap_reconstructor.load_hyper_params(hyper_params_path)
In [ ]:
print( iradonmap_reconstructor.hyper_params )
{'epochs': 150, 'batch_size': 2, 'lr': 0.01, 'normalize_by_opnorm': False, 'scales': 5, 'skip_channels': 4, 'fully_learned': True, 'use_sigmoid': False}
In [ ]:
# Shorten training to a single epoch for this experiment (reference is 150).
iradonmap_reconstructor.hyper_params['epochs'] = 1
# iradonmap_reconstructor.hyper_params['batch_size'] = 2 # OOM when more than 2
print( iradonmap_reconstructor.hyper_params )
{'epochs': 1, 'batch_size': 2, 'lr': 0.01, 'normalize_by_opnorm': False, 'scales': 5, 'skip_channels': 4, 'fully_learned': True, 'use_sigmoid': False}
In [ ]:
#%% TRAINING MODEL
# iradonmap_reconstructor.train(dataset) # make sure 14GB VRAM is available
In [ ]:
"""
To use the trained model for inference without retraining, you need to load the saved parameters. The `IRadonMapReconstructor` saves the best learned parameters during training to the path specified by `save_best_learned_params_path`. You can then load these parameters using the `load_learned_params` method and use the `reconstruct` method on new observations.
"""
Out[ ]:
'\nTo use the trained model for inference without retraining, you need to load the saved parameters. The `IRadonMapReconstructor` saves the best learned parameters during training to the path specified by `save_best_learned_params_path`. You can then load these parameters using the `load_learned_params` method and use the `reconstruct` method on new observations.\n'
In [20]:
# Load the trained parameters if u've already trained the model
iradonmap_reconstructor.load_learned_params(SAVE_BEST_LEARNED_PARAMS_PATH_iradonmap)

# Hyper param load
# iradonmap_reconstructor.load_params(SAVE_BEST_LEARNED_PARAMS_PATH_iradonmap)

# Assuming you have a new observation 'new_obs'
# new_obs = ...

# Reconstruct using the loaded model
# reconstructed_image = reconstructor.reconstruct(new_obs)

# You can then evaluate the reconstructed image as needed
# psnr = PSNR(reconstructed_image, ground_truth_of_new_obs)
# print('PSNR for new observation: {:f}'.format(psnr))
In [ ]:
#%% evaluate
# Reconstruct every test observation and report per-sample and mean PSNR.
recos = []
psnrs = []
for obs, gt in test_data_10:
    reco = iradonmap_reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs)))

# Plot each reconstruction next to its ground truth.
# (Was a hard-coded range(10); now follows the number of reconstructions so
# the cell still works if the test set size changes.)
for i in range(len(recos)):
    _, ax = plot_images([recos[i], test_data_10.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[i]))
    ax[0].set_title('IRadonMapReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))
mean psnr: 30.403244
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Experiment - Training Grand Challenge winner - LearnedPrimalDual Neural Network¶

In [ ]:
"""
Train LearnedPDReconstructor on 'lodopab'.
"""
from dival.reconstructors.learnedpd_reconstructor import LearnedPDReconstructor
from dival.reference_reconstructors import (
    check_for_params, download_params, get_hyper_params_path)
In [ ]:
# pip install git+https://github.com/odlgroup/odl.git # went from odl stable 0.82 to for the night build 1.00dev
In [ ]:
# import odl.contrib.torch.operator as odlt # fix notes below
"""
  issue fix for This comes from odl/contrib/torch/operator.py in your odl==0.8.2 installation.
  That means in this release, the symbol AVOID_UNNECESSARY_COPY was referenced but not actually defined or imported properly.

  if not hasattr(odlt, "AVOID_UNNECESSARY_COPY"):
      # False is the safe option: numpy.astype(copy=False) avoids forced copies.
      # If you find correctness issues, set to True (forces a copy).
      odlt.AVOID_UNNECESSARY_COPY = False
"""
Out[ ]:
'\n  issue fix for This comes from odl/contrib/torch/operator.py in your odl==0.8.2 installation.\n  That means in this release, the symbol AVOID_UNNECESSARY_COPY was referenced but not actually defined or imported properly.\n\n  if not hasattr(odlt, "AVOID_UNNECESSARY_COPY"):\n      # False is the safe option: numpy.astype(copy=False) avoids forced copies.\n      # If you find correctness issues, set to True (forces a copy).\n      odlt.AVOID_UNNECESSARY_COPY = False\n'
In [ ]:
# if not hasattr(odlt, "AVOID_UNNECESSARY_COPY"): # fix notes above
    # False is the safe option: numpy.astype(copy=False) avoids forced copies.
    # If you find correctness issues, set to True (forces a copy).
    # odlt.AVOID_UNNECESSARY_COPY = False
In [ ]:
# Bug fix required before calling reconstructor.train(dataset):
# odl 0.8.2's `odl/contrib/torch/operator.py` references a module-global
# boolean `AVOID_UNNECESSARY_COPY` inside `OperatorFunction.forward` (used as
# the `copy=` argument to `np.astype(...)`), but the release never defines it,
# so training raises NameError. Defining it on the module restores operation.

import odl.contrib.torch.operator as odl_op
# False lets numpy's astype skip the copy when possible (the safe default).
# Set True to force copies if you observe strange in-place behavior.
odl_op.AVOID_UNNECESSARY_COPY = False
In [ ]:
# IMPL = 'astra_cpu' # for drive CPUs

# Paths when running locally (relative to the notebook's working directory).
# os.getcwd()
LOG_DIR_learnedpd = '../../logs/lodopab_learnedpd'
SAVE_BEST_LEARNED_PARAMS_PATH_learnedpd = '../../params/lodopab_learnedpd'

# # Paths when running from Google Drive instead:
# # os.getcwd()
# # %cd '/content/drive/MyDrive/Colab Notebooks/Restormer_1209/Denoising'
# LOG_DIR = '/content/drive/MyDrive/Colab Notebooks/Restormer_1209/Denoising/experiments/lodopab_learnedpd/logs'
# SAVE_BEST_LEARNED_PARAMS_PATH = '/content/drive/MyDrive/Colab Notebooks/Restormer_1209/Denoising/experiments/lodopab_learnedpd/params'

# del reconstructor
# Build the Learned Primal-Dual reconstructor on the lodopab ray transform;
# best weights seen during training are saved under the configured path.
learnedpd_reconstructor = LearnedPDReconstructor(
    ray_trafo_lodopab, log_dir=LOG_DIR_learnedpd,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH_learnedpd)

#%% obtain reference hyper parameters if training, else comment out

# if not check_for_params('learnedpd', 'lodopab', include_learned=False): # include_learned=True would also download weights
    # download_params('learnedpd', 'lodopab', include_learned=False)
# hyper_params_path = get_hyper_params_path('learnedpd', 'lodopab')
# learnedpd_reconstructor.load_hyper_params(hyper_params_path)
In [ ]:
# Show the reference hyper parameters, cut training to one epoch, re-show.
print( learnedpd_reconstructor.hyper_params )
learnedpd_reconstructor.hyper_params['epochs'] = 1
print( learnedpd_reconstructor.hyper_params )
{'epochs': 10, 'batch_size': 1, 'lr': 0.0001, 'normalize_by_opnorm': True, 'lr_min': 0.0001, 'niter': 10, 'init_fbp': True, 'init_filter_type': 'Hann', 'init_frequency_scaling': 0.7, 'nprimal': 5, 'ndual': 5, 'use_sigmoid': False, 'nlayer': 3, 'internal_ch': 64, 'kernel_size': 3, 'batch_norm': False, 'prelu': True, 'lrelu_coeff': 0.2}
{'epochs': 1, 'batch_size': 1, 'lr': 0.0001, 'normalize_by_opnorm': True, 'lr_min': 0.0001, 'niter': 10, 'init_fbp': True, 'init_filter_type': 'Hann', 'init_frequency_scaling': 0.7, 'nprimal': 5, 'ndual': 5, 'use_sigmoid': False, 'nlayer': 3, 'internal_ch': 64, 'kernel_size': 3, 'batch_norm': False, 'prelu': True, 'lrelu_coeff': 0.2}
In [ ]:
#%% TRAINING MODEL
# learnedpd_reconstructor.train(dataset)
epoch 1: 100%|███████████████████████████| 35820/35820 [5:17:38<00:00,  1.88it/s, phase=train, loss=0.000176, psnr=35.4]
epoch 1: 100%|██████████████████████████| 3522/3522 [12:50<00:00,  4.57it/s, phase=validation, loss=0.000123, psnr=36.5]
Best val psnr: 36.529610

In [ ]:
#%% evaluate
# Run LearnedPD inference over the held-out pairs and collect PSNR values.
recos, psnrs = [], []
for observation, ground_truth in test_data_10:
    reconstruction = learnedpd_reconstructor.reconstruct(observation)
    recos.append(reconstruction)
    psnrs.append(PSNR(reconstruction, ground_truth))

print('mean psnr: {:f}'.format(np.mean(psnrs)))

# Display the first three samples side by side with the ground truth.
for sample_idx in range(3):
    image_pair = [recos[sample_idx], test_data_10.ground_truth[sample_idx]]
    _, ax = plot_images(image_pair, fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[sample_idx]))
    ax[0].set_title('LearnedPDReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(sample_idx))
mean psnr: 37.218559
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# If already trained, load the published reference weights.
# Fix: `get_params_path` was never imported (earlier cells only import
# check_for_params, download_params and get_hyper_params_path), so this cell
# raised NameError on line 3. Import it here.
from dival.reference_reconstructors import get_params_path

if not check_for_params('learnedpd', 'lodopab'):
    download_params('learnedpd', 'lodopab')
params_path = get_params_path('learnedpd', 'lodopab')
learnedpd_reconstructor.load_params(params_path)
# weights land under e.g. ~/.dival/reference_params/lodopab/lodopab_learnedpd

Experimenting on training setup¶

  • epoch 1: 100%|███████████████████████████| 35820/35820 [5:17:38<00:00, 1.88it/s, phase=train, loss=0.000176, psnr=35.4]
  • epoch 1: 100%|██████████████████████████| 3522/3522 [12:50<00:00, 4.57it/s, phase=validation, loss=0.000123, psnr=36.5]Best val psnr: 36.529610
In [ ]:
learnedpd_reconstructor.train??
type(learnedpd_reconstructor)
Out[ ]:
dival.reconstructors.learnedpd_reconstructor.LearnedPDReconstructor

Learnedpd train code - original¶

In [ ]:
# Signature: reconstructor.train(dataset)
# File:      /usr/local/lib/python3.12/dist-packages/dival/reconstructors/standard_learned_reconstructor.py

# -*- coding: utf-8 -*-

try:
    import torch
except ImportError:
    raise ImportError('missing PyTorch')

# import os
import copy
from copy import deepcopy
from math import ceil
# import odl
# import numpy as np
# from tqdm import tqdm

from torch.utils.data import DataLoader
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    TENSORBOARD_AVAILABLE = False
else:
    TENSORBOARD_AVAILABLE = True
from torch.optim.lr_scheduler import CyclicLR, OneCycleLR

from dival.reconstructors import LearnedReconstructor
from dival.util.torch_utility import load_state_dict_convert_data_parallel
from copy import deepcopy

class StandardLearnedReconstructor_demo(LearnedReconstructor):
    """
    Standard learned reconstructor base class.

    Provides a default implementation that only requires subclasses to
    implement :meth:`init_model`.

    By default, the Adam optimizer is used. This can be changed by
    reimplementing :meth:`init_optimizer`.
    Also, a OneCycleLR scheduler is used by default, which can be changed by
    reimplementing :meth:`init_scheduler`.

    The training implementation selects the best model reached after an integer
    number of epochs based on the validation set.

    The hyper parameter ``'normalize_by_opnorm'`` selects whether
    :attr:`op` should be normalized by the operator norm.
    In this case, the inputs to :attr:`model` are divided by the operator norm.

    Attributes
    ----------
    model : :class:`torch.nn.Module` or `None`
        The neural network.
        Must be initialized by the subclass :meth:`init_model` implementation.
    non_normed_op : :class:`odl.operator.Operator`
        The original `op` passed to :meth:`__init__`, regardless of
        ``self.hyper_params['normalize_by_opnorm']``.
        See also :attr:`op`.
    """

    # Deep-copied so this subclass can extend the base dict without mutating
    # LearnedReconstructor.HYPER_PARAMS for every other subclass.
    HYPER_PARAMS = deepcopy(LearnedReconstructor.HYPER_PARAMS)
    # 'retrain': True marks parameters whose change invalidates trained weights.
    HYPER_PARAMS.update({
        'epochs': {
            'default': 20,
            'retrain': True
        },
        'batch_size': {
            'default': 64,
            'retrain': True
        },
        'lr': {
            'default': 0.01,
            'retrain': True
        },
        'normalize_by_opnorm': {
            'default': False,
            'retrain': True
        }
    })

    def __init__(self, op, hyper_params=None, num_data_loader_workers=8,
                 use_cuda=True, show_pbar=True, log_dir=None,
                 log_num_validation_samples=0,
                 save_best_learned_params_path=None, torch_manual_seed=1,
                 shuffle='auto', worker_init_fn=None, **kwargs):
        """
        Parameters
        ----------
        op : :class:`odl.operator.Operator`
            Forward operator.
        num_data_loader_workers : int, optional
            Number of parallel workers to use for loading data.
        use_cuda : bool, optional
            Whether to use cuda for the U-Net.
        show_pbar : bool, optional
            Whether to show tqdm progress bars during the epochs.
        log_dir : str, optional
            Tensorboard log directory (name of sub-directory in utils/logs).
            If `None`, no logs are written.
        log_num_validation_samples : int, optional
            Number of validation images to store in tensorboard logs.
            This option only takes effect if ``log_dir is not None``.
        save_best_learned_params_path : str, optional
            Save best model weights during training under the specified path by
            calling :meth:`save_learned_params`.
        torch_manual_seed : int, optional
            Fixed seed to set by ``torch.manual_seed`` before training.
            The default is `1`. It can be set to `None` or `False` to disable
            the manual seed.
        shuffle : {``'auto'``, ``False``, ``True``}, optional
            Whether to use shuffling when loading data.
            When ``'auto'`` is specified (the default), ``True`` is used iff
            the dataset passed to :meth:`train` supports random access.
        worker_init_fn : callable, optional
            Callable `worker_init_fn` passed to
            :meth:`torch.utils.data.DataLoader.__init__`, which can be used to
            configure the dataset copies for different worker instances
            (cf. `torch's IterableDataset docs <https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset>`_)
        """
        # Spaces are derived from the forward operator: domain = reco space,
        # range = observation space.
        super().__init__(reco_space=op.domain,
                         observation_space=op.range,
                         hyper_params=hyper_params, **kwargs)
        self.non_normed_op = op
        self.num_data_loader_workers = num_data_loader_workers
        self.use_cuda = use_cuda
        self.show_pbar = show_pbar
        self.log_dir = log_dir
        self.log_num_validation_samples = log_num_validation_samples
        self.save_best_learned_params_path = save_best_learned_params_path
        self.torch_manual_seed = torch_manual_seed
        self.shuffle = shuffle
        self.worker_init_fn = worker_init_fn
        # Model/optimizer/scheduler are created lazily in train().
        self.model = None
        self._optimizer = None
        self._scheduler = None

        # Operator norm cache, computed on first access of `opnorm`.
        self._opnorm = None

        # Falls back to CPU when CUDA was requested but is unavailable.
        self.device = (torch.device('cuda:0')
                       if self.use_cuda and torch.cuda.is_available() else
                       torch.device('cpu'))

    @property
    def opnorm(self):
        """float : Operator norm of :attr:`non_normed_op`, computed lazily
        via power iteration on first access and cached afterwards.
        NOTE(review): relies on notebook-global `odl` (not imported in this
        copied snippet)."""
        if self._opnorm is None:
            self._opnorm = odl.power_method_opnorm(self.non_normed_op)
        return self._opnorm

    @property
    def op(self):
        """
        :class:`odl.operator.Operator` :
        The forward operator, normalized if
        ``self.hyper_params['normalize_by_opnorm']`` is ``True``.
        """
        if self.normalize_by_opnorm:
            # Scale the operator so its norm is ~1 (first access computes it).
            return (1./self.opnorm) * self.non_normed_op
        return self.non_normed_op

    def eval(self, test_data):
        """Evaluate the trained model and return the mean PSNR over
        `test_data` (iterable of ``(observation, ground_truth)`` pairs).
        NOTE(review): uses notebook-global `tqdm` and `PSNR` — not imported in
        this copied snippet."""
        self.model.eval()  # switch off dropout/batch-norm updates

        running_psnr = 0.0
        with tqdm(test_data, desc='test ',
                  disable=not self.show_pbar) as pbar:
            for obs, gt in pbar:
                rec = self.reconstruct(obs)
                running_psnr += PSNR(rec, gt)

        return running_psnr / len(test_data)

    def train(self, dataset):
        """Train :attr:`model` on `dataset` with MSE loss.

        Runs ``self.epochs`` epochs of train + validation, tracks PSNR, keeps
        the weights of the best validation epoch (and saves them to
        ``save_best_learned_params_path`` if set), and finally loads those
        best weights back into the model. Optionally logs loss/PSNR and
        sample reconstructions to tensorboard when ``log_dir`` is set.

        Parameters
        ----------
        dataset : :class:`dival.datasets.dataset.Dataset`
            Dival dataset providing 'train' and 'validation' parts.
        """
        if self.torch_manual_seed:
            torch.random.manual_seed(self.torch_manual_seed)

        self.init_transform(dataset=dataset) # initiate transformation, see what meth does below

        # create PyTorch datasets # from parameter dataset - 'lodopab' # enables torch meth
        dataset_train = dataset.create_torch_dataset(
            part='train', reshape=((1,) + dataset.space[0].shape,
                                   (1,) + dataset.space[1].shape),
            transform=self._transform)
        # reshaped ( (1, 1000, 513) , (1, 362, 362) ) # prepends a channel dim

        # validation gets no transform (no data augmentation at eval time)
        dataset_validation = dataset.create_torch_dataset(
            part='validation', reshape=((1,) + dataset.space[0].shape,
                                        (1,) + dataset.space[1].shape))

        # reset model before training
        self.init_model() # grad function = True # trn mode

        criterion = torch.nn.MSELoss() # loss function
        self.init_optimizer(dataset_train=dataset_train)

        # create PyTorch dataloaders
        shuffle = (dataset.supports_random_access() if self.shuffle == 'auto'
                   else self.shuffle)
        data_loaders = {
            'train': DataLoader(
                dataset_train, batch_size=self.batch_size,
                num_workers=self.num_data_loader_workers, shuffle=shuffle,
                pin_memory=True, worker_init_fn=self.worker_init_fn),
            'validation': DataLoader(
                dataset_validation, batch_size=self.batch_size,
                num_workers=self.num_data_loader_workers, shuffle=shuffle,
                pin_memory=True, worker_init_fn=self.worker_init_fn)}

        dataset_sizes = {'train': len(dataset_train),
                         'validation': len(dataset_validation)}

        self.init_scheduler(dataset_train=dataset_train)
        # cyclic schedulers step per batch, others step per epoch
        if self._scheduler is not None:
            schedule_every_batch = isinstance(
                self._scheduler, (CyclicLR, OneCycleLR))

        # snapshot of the best weights seen so far (restored at the end)
        best_model_wts = deepcopy(self.model.state_dict())
        best_psnr = 0

        if self.log_dir is not None:
            if not TENSORBOARD_AVAILABLE:
                raise ImportError(
                    'Missing tensorboard. Please install it or disable '
                    'logging by specifying `log_dir=None`.')
            writer = SummaryWriter(log_dir=self.log_dir, max_queue=0)
            validation_samples = dataset.get_data_pairs(
                'validation', self.log_num_validation_samples)

        self.model.to(self.device)
        self.model.train()

        for epoch in range(self.epochs):
            # Each epoch has a training and validation phase
            for phase in ['train', 'validation']:
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                running_psnr = 0.0
                running_loss = 0.0
                running_size = 0
                with tqdm(data_loaders[phase],
                          desc='epoch {:d}'.format(epoch + 1),
                          disable=not self.show_pbar) as pbar:
                    for inputs, labels in pbar:
                        if self.normalize_by_opnorm:
                            # inputs are divided by the operator norm to match
                            # the normalized forward operator (see `op`)
                            inputs = (1./self.opnorm) * inputs
                        inputs = inputs.to(self.device)
                        labels = labels.to(self.device)

                        # zero the parameter gradients
                        self._optimizer.zero_grad()

                        # forward
                        # track gradients only if in train phase
                        with torch.set_grad_enabled(phase == 'train'):
                            outputs = self.model(inputs)
                            loss = criterion(outputs, labels)

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                loss.backward()
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.parameters(), max_norm=1)
                                self._optimizer.step()
                                if (self._scheduler is not None and
                                        schedule_every_batch):
                                    self._scheduler.step()

                        # per-sample PSNR, accumulated over the batch
                        for i in range(outputs.shape[0]):
                            labels_ = labels[i, 0].detach().cpu().numpy()
                            outputs_ = outputs[i, 0].detach().cpu().numpy()
                            running_psnr += PSNR(outputs_, labels_)

                        # statistics
                        running_loss += loss.item() * outputs.shape[0]
                        running_size += outputs.shape[0]

                        pbar.set_postfix({'phase': phase,
                                          'loss': running_loss/running_size,
                                          'psnr': running_psnr/running_size})
                        if self.log_dir is not None and phase == 'train':
                            # global step counted in optimizer steps
                            step = (epoch * ceil(dataset_sizes['train']
                                                 / self.batch_size)
                                    + ceil(running_size / self.batch_size))
                            writer.add_scalar(
                                'loss/{}'.format(phase),
                                torch.tensor(running_loss/running_size), step)
                            writer.add_scalar(
                                'psnr/{}'.format(phase),
                                torch.tensor(running_psnr/running_size), step)

                    if (self._scheduler is not None
                            and not schedule_every_batch):
                        self._scheduler.step()

                    epoch_loss = running_loss / dataset_sizes[phase]
                    epoch_psnr = running_psnr / dataset_sizes[phase]

                    if self.log_dir is not None and phase == 'validation':
                        step = (epoch+1) * ceil(dataset_sizes['train']
                                                / self.batch_size)
                        writer.add_scalar('loss/{}'.format(phase),
                                          epoch_loss, step)
                        writer.add_scalar('psnr/{}'.format(phase),
                                          epoch_psnr, step)

                    # deep copy the model (if it is the best one seen so far)
                    if phase == 'validation' and epoch_psnr > best_psnr:
                        best_psnr = epoch_psnr
                        best_model_wts = deepcopy(self.model.state_dict())
                        if self.save_best_learned_params_path is not None:
                            self.save_learned_params(
                                self.save_best_learned_params_path)

                # optionally log example reconstructions (min-max normalized)
                if (phase == 'validation' and self.log_dir is not None and
                        self.log_num_validation_samples > 0):
                    with torch.no_grad():
                        val_images = []
                        for (y, x) in validation_samples:
                            y = torch.from_numpy(
                                np.asarray(y))[None, None].to(self.device)
                            x = torch.from_numpy(
                                np.asarray(x))[None, None].to(self.device)
                            reco = self.model(y)
                            reco -= torch.min(reco)
                            reco /= torch.max(reco)
                            val_images += [reco, x]
                        writer.add_images(
                            'validation_samples', torch.cat(val_images),
                            (epoch + 1) * (ceil(dataset_sizes['train'] /
                                                self.batch_size)),
                            dataformats='NCWH')

        print('Best val psnr: {:4f}'.format(best_psnr))
        # restore the weights of the best validation epoch
        self.model.load_state_dict(best_model_wts)

    def init_transform(self, dataset):
        """
        Initialize the transform (:attr:`_transform`) that is applied on each
        training sample, e.g. for data augmentation.
        In the default implementation of :meth:`train`, it is passed to
        :meth:`Dataset.create_torch_dataset` when creating the training (but
        not the validation) torch dataset, which applies the transform to the
        (tuple of) torch tensor(s) right before returning, i.e. after reshaping
        to ``(1,) + orig_shape``.

        The default implementation of this method disables the transform by
        assigning `None`.
        Called in :meth:`train` at the beginning, i.e. before calling
        :meth:`init_model`, :meth:`init_optimizer` and :meth:`init_scheduler`.

        Parameters
        ----------
        dataset : :class:`dival.datasets.dataset.Dataset`
            The dival dataset passed to :meth:`train`.
        """
        self._transform = None  # no augmentation by default; subclasses override

    @property
    def transform(self):
        """
        callable :
        Per-sample transform, normally assigned by :meth:`init_transform`
        (invoked from :meth:`train`).
        """
        return self._transform

    @transform.setter
    def transform(self, value):
        # Direct passthrough to the backing attribute.
        self._transform = value

    def init_model(self):
        """
        Create :attr:`model`.

        Abstract hook: subclasses must override it.  :meth:`train` invokes it
        after :meth:`init_transform`, but before :meth:`init_optimizer` and
        :meth:`init_scheduler`.
        """
        raise NotImplementedError

    def init_optimizer(self, dataset_train):
        """
        Create the optimizer (:attr:`_optimizer`).

        Invoked by :meth:`train` after :meth:`init_transform` and
        :meth:`init_model`, but before :meth:`init_scheduler`.  The default
        is plain Adam over all model parameters with learning rate
        :attr:`lr`.

        Parameters
        ----------
        dataset_train : :class:`torch.utils.data.Dataset`
            The training (torch) dataset constructed in :meth:`train`.
        """
        trainable_params = self.model.parameters()
        self._optimizer = torch.optim.Adam(trainable_params, lr=self.lr)

    @property
    def optimizer(self):
        """
        :class:`torch.optim.Optimizer` :
        Optimizer instance, normally created by :meth:`init_optimizer`
        (invoked from :meth:`train`).
        """
        return self._optimizer

    @optimizer.setter
    def optimizer(self, value):
        # Direct passthrough to the backing attribute.
        self._optimizer = value

    def init_scheduler(self, dataset_train):
        """
        Create the learning rate scheduler (:attr:`_scheduler`).

        Invoked by :meth:`train` last among the init hooks, i.e. after
        :meth:`init_transform`, :meth:`init_model` and
        :meth:`init_optimizer`.  The default is a one-cycle LR policy over
        the full training run.

        Parameters
        ----------
        dataset_train : :class:`torch.utils.data.Dataset`
            The training (torch) dataset constructed in :meth:`train`.
        """
        # Optimizer steps per epoch, counting the final partial batch.
        steps_per_epoch = ceil(len(dataset_train) / self.batch_size)
        self._scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self._optimizer, max_lr=self.lr,
            steps_per_epoch=steps_per_epoch, epochs=self.epochs)

    @property
    def scheduler(self):
        """
        torch learning rate scheduler :
        Scheduler instance, normally created by :meth:`init_scheduler`
        (invoked from :meth:`train`).
        """
        return self._scheduler

    @scheduler.setter
    def scheduler(self, value):
        # Direct passthrough to the backing attribute.
        self._scheduler = value

    def _reconstruct(self, observation):
        # Push one observation through the network in inference mode and wrap
        # the result as an element of the reconstruction space.
        self.model.eval()
        with torch.no_grad():
            # Add batch and channel axes: (H, W) -> (1, 1, H, W).
            obs_tensor = torch.from_numpy(np.asarray(observation)[None, None])
            if self.normalize_by_opnorm:
                # Scale down by the operator norm, matching the training setup.
                obs_tensor = obs_tensor / self.opnorm
            reco_tensor = self.model(obs_tensor.to(self.device))
            # Strip batch/channel axes again before wrapping.
            reconstruction = reco_tensor.cpu().numpy()[0, 0]
        return self.reco_space.element(reconstruction)

    def save_learned_params(self, path):
        """Store the model's state dict at `path` (a ``'.pt'`` suffix is
        appended if missing); parent directories are created as needed."""
        if not path.endswith('.pt'):
            path += '.pt'
        path = os.path.abspath(path)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def load_learned_params(self, path, convert_data_parallel='auto'):
        """Load learned parameters from file.

        Parameters
        ----------
        path : str
            Path at which the learned parameters are stored.
            Implementations may interpret this as a file path or as a
            directory path for multiple files.
            If the implementation expects a file path, it should accept it
            without file ending.
        convert_data_parallel : bool or {``'auto'``, ``'keep'``}, optional
            Whether to automatically translate the weight names when
            :attr:`model` is an :class:`nn.DataParallel` model but the stored
            state dict stems from a non-data-parallel model, or vice versa.

                ``'auto'`` or ``True``:
                    Auto-convert weight names, depending on the type of
                    :attr:`model`.
                ``'keep'`` or ``False``:
                    Load the stored weight names unchanged.
        """
        if not path.endswith('.pt'):
            path += '.pt'
        self.init_model()
        # Map to GPU 0 only when CUDA is both requested and available.
        use_gpu = self.use_cuda and torch.cuda.is_available()
        state_dict = torch.load(path,
                                map_location='cuda:0' if use_gpu else 'cpu')

        if convert_data_parallel in ('auto', True):
            load_state_dict_convert_data_parallel(self.model, state_dict)
        elif convert_data_parallel in ('keep', False):
            self.model.load_state_dict(state_dict)
        else:
            raise ValueError("Unknown option '{}' for `convert_data_parallel`"
                             .format(convert_data_parallel))
In [ ]:
# Scratch cell: inspect how samples are fed into training.
# NOTE(review): relies on `dataset` / `dataset_train` created earlier in the
# kernel session; raises NameError on a fresh "Restart & Run All".

# self = reconstructor
# print(self._scheduler)
# dataset.supports_random_access()
# self.shuffle

# print('\n', dataset_train.__getitem__(0) )  # torch dataset indexing
# # print('\n',dataset.__getitem__(0) )  # fails: the LoDoPaB dataset has no __getitem__
# print('\n', dataset_train.dataset.get_sample(0, part='train') )
# print('\n', dataset_train.dataset.get_sample(0) )  # same as above: the wrapped dataset defaults the part
print('\n', dataset.get_sample(0, part='validation') ) # LoDoPaB dataset: call get_sample directly
print('\n', dataset.get_sample(0) ) # a torch random-access wrapper would need `.dataset` first
 (uniform_discr(
    [ 0.        , -0.18384776], [ 3.14159265,  0.18384776],
    (1000, 513),
    dtype='float32'
).element(
    [[ 0.00005112,  0.00005112,  0.0002455 , ...,  0.00001201, -0.00017578,
       0.00007223],
     [ 0.00025775,  0.00009943, -0.00025834, ..., -0.000003  ,  0.0001692 ,
      -0.00014913],
     [ 0.00025469, -0.00004493,  0.00001501, ...,  0.00014489,  0.00022409,
      -0.00001499],
     ..., 
     [-0.00002997,  0.00020882,  0.00022103, ...,  0.00000901, -0.0000897 ,
      -0.00015506],
     [-0.00020828, -0.00010458, -0.00013133, ...,  0.00039643,  0.00008431,
       0.0001692 ],
     [-0.00008374, -0.00019057,  0.00017833, ...,  0.00012668, -0.00019352,
      -0.00032281]]
), uniform_discr(
    [-0.13, -0.13], [ 0.13,  0.13], (362, 362), dtype='float32'
).element(
    [[ 0.00261918,  0.00156398,  0.0019433 , ...,  0.1303666 ,  0.07036828,
       0.03579117],
     [ 0.00025373,  0.00038949,  0.00038206, ...,  0.1641931 ,  0.11035281,
       0.05525199],
     [ 0.00037611,  0.0004624 ,  0.00028693, ...,  0.17303914,  0.1496849 ,
       0.09620444],
     ..., 
     [ 0.00029302,  0.00046841,  0.00105497, ...,  0.22505756,  0.2250221 ,
       0.22316544],
     [ 0.00038185,  0.00038518,  0.00059254, ...,  0.22446391,  0.227184  ,
       0.22562501],
     [ 0.00036999,  0.00026109,  0.00026571, ...,  0.22011115,  0.22439928,
       0.22344765]]
))

 (uniform_discr(
    [ 0.        , -0.18384776], [ 3.14159265,  0.18384776],
    (1000, 513),
    dtype='float32'
).element(
    [[-0.00007181, -0.00004194, -0.00024069, ..., -0.00010458,  0.00001802,
       0.00014489],
     [-0.00011648,  0.00006921, -0.00017874, ..., -0.00011945, -0.000006  ,
       0.00017225],
     [-0.00028182, -0.00003296, -0.00020238, ..., -0.00005091, -0.00009565,
      -0.00003895],
     ..., 
     [-0.00003596,  0.000006  , -0.00023186, ...,  0.00044923, -0.00002398,
       0.00013578],
     [ 0.00003907, -0.00028475, -0.00007479, ..., -0.00040148, -0.00004493,
      -0.00004493],
     [ 0.00003907,  0.00020577,  0.00041504, ...,  0.00008431,  0.0002302 ,
       0.00017529]]
), uniform_discr(
    [-0.13, -0.13], [ 0.13,  0.13], (362, 362), dtype='float32'
).element(
    [[ 0.00412741,  0.00092089,  0.        , ...,  0.00818573,  0.00895679,
       0.00875593],
     [ 0.00207709,  0.00036544,  0.        , ...,  0.01021071,  0.0117337 ,
       0.01011639],
     [ 0.00189564,  0.0005095 ,  0.00305212, ...,  0.00841267,  0.01071482,
       0.0110803 ],
     ..., 
     [ 0.00089869,  0.00204101,  0.0024129 , ...,  0.0210505 ,  0.01509479,
       0.01488117],
     [ 0.        ,  0.00089668,  0.00360478, ...,  0.01826979,  0.01664608,
       0.01685752],
     [ 0.        ,  0.00233886,  0.00510937, ...,  0.01820439,  0.01630987,
       0.01301141]]
))

Experimenting on training setup - contd¶

In [ ]:
LearnedReconstructor??
In [ ]:
# Scratch cell: derive the DataLoader `shuffle` flag the way the trainer does.
# NOTE(review): this line mixes `learnedpd_reconstructor` and `reconstructor` —
# presumably both should reference the same reconstructor instance; confirm
# before reusing this snippet.

shuffle = (dataset.supports_random_access() if learnedpd_reconstructor.shuffle == 'auto'
            else reconstructor.shuffle)
"""
data_loaders = {'train': DataLoader(
        dataset_trainT, batch_size=reconstructor.batch_size,
        num_workers=reconstructor.num_data_loader_workers, shuffle=shuffle,
        pin_memory=True, worker_init_fn=reconstructor.worker_init_fn),'validation': DataLoader(
        dataset_validationT, batch_size=reconstructor.batch_size,
        num_workers=reconstructor.num_data_loader_workers, shuffle=shuffle,
        pin_memory=True, worker_init_fn=reconstructor.worker_init_fn)} # replaced self. with recn.

type(data_loaders['train'])
"""
Out[ ]:
"\ndata_loaders = {'train': DataLoader(\n        dataset_trainT, batch_size=reconstructor.batch_size,\n        num_workers=reconstructor.num_data_loader_workers, shuffle=shuffle,\n        pin_memory=True, worker_init_fn=reconstructor.worker_init_fn),'validation': DataLoader(\n        dataset_validationT, batch_size=reconstructor.batch_size,\n        num_workers=reconstructor.num_data_loader_workers, shuffle=shuffle,\n        pin_memory=True, worker_init_fn=reconstructor.worker_init_fn)} # replaced self. with recn.\n\ntype(data_loaders['train'])\n"

Experiment Inferencing with DeepImagePrior¶

In [ ]:
from dival.reconstructors.dip_ct_reconstructor import (
    DeepImagePriorCTReconstructor)
In [ ]:
# Pick one test sample (observation + ground truth) for the DIP experiment.
TEST_SAMPLE = 0
obs, gt = dataset.get_sample(TEST_SAMPLE, 'test')

def callback_func(iteration, reconstruction, loss):
    """Plot the current DIP reconstruction next to the ground truth.

    Passed to DeepImagePriorCTReconstructor below, which calls it every
    `callback_func_interval` iterations with the running loss value.
    """
    _, ax = plot_images([reconstruction, gt],
                        fig_size=(10, 4))
    ax[0].set_xlabel('loss: {:f}'.format(loss))
    ax[0].set_title('DIP iteration {:d}'.format(iteration))
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
    plt.show()
In [ ]:
# Build a DIP+TV reconstructor for the LoDoPaB ray transform, wired to the
# plotting callback defined above.
diptv_reconstructor = DeepImagePriorCTReconstructor(
    dataset.get_ray_trafo(impl=IMPL),
    callback_func=callback_func, callback_func_interval=4000)

#%% obtain reference hyper parameters (downloaded once, then cached locally)
if not check_for_params('diptv', 'lodopab'):
    download_params('diptv', 'lodopab')
params_path = get_params_path('diptv', 'lodopab')
diptv_reconstructor.load_params(params_path)

#%% evaluate on a single test sample (DIP optimizes per image, so this is slow)
reco = diptv_reconstructor.reconstruct(obs)
psnr = PSNR(reco, gt)

print('psnr: {:f}'.format(psnr))
_, ax = plot_images([reco, gt],
                    fig_size=(10, 4))
ax[0].set_xlabel('PSNR: {:.2f}'.format(psnr))
ax[0].set_title('DeepImagePriorCTReconstructor')
ax[1].set_title('ground truth')
ax[0].figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
DIP:   0%|                                                                                    | 0/17000 [00:00<?, ?it/s]
No description has been provided for this image
DIP:  11%|████████                                                                 | 1883/17000 [01:54<15:16, 16.49it/s]

KeyboardInterrupt

Transition layer needed for a Hybrid Architecture - Build Phase¶

In [21]:
def odl_elem_to_torch(elem, device=None):
    """
    Convert an ``odl.discr.discr_space.DiscretizedSpaceElement`` (or any
    array-like) `elem` to a ``torch.Tensor`` of shape ``(1, C, H, W)`` and
    dtype ``torch.float32``.

    - Uses ``np.asarray(elem)`` to obtain a numpy view/copy.
    - Ensures contiguous memory and float32 dtype in a single conversion
      (the original if/else performed the same float32 cast in both
      branches, so the branch was dead weight).
    - 2D input ``(H, W)`` becomes ``(1, H, W)`` (single channel).
    - 3D input with a small last axis (1, 3 or 4) is heuristically treated
      as HWC and transposed to CHW; otherwise it is assumed CHW already.

    Parameters
    ----------
    elem : array-like
        ODL element or anything ``np.asarray`` accepts.
    device : torch.device or str, optional
        If given, the resulting tensor is moved to this device.

    Returns
    -------
    torch.Tensor
        Batched tensor of shape ``(1, C, H, W)``, dtype float32.
    """
    # Single step: contiguous float32 copy/view of the input data.
    arr = np.ascontiguousarray(np.asarray(elem), dtype=np.float32)

    # Normalize shape to (C, H, W).
    if arr.ndim == 2:
        arr = arr[None, ...]              # (H, W) -> (1, H, W)
    elif arr.ndim == 3 and arr.shape[2] in (1, 3, 4):
        arr = arr.transpose(2, 0, 1)      # small trailing axis: HWC -> CHW
    # else: assume the array is already (C, H, W)

    # Add the batch dimension -> (1, C, H, W).
    tensor = torch.from_numpy(arr).unsqueeze(0)

    if device is not None:
        tensor = tensor.to(device)

    return tensor

def odl_to_single(elem):    # shape change (H, W) -> (H, W, 1)
    """
    Convert an ``odl.discr.discr_space.DiscretizedSpaceElement`` (or any
    array-like) `elem` to a contiguous ``np.float32`` array with an explicit
    trailing channel axis.

    - Uses ``np.asarray(elem)`` to obtain a numpy view/copy.
    - Ensures contiguous memory and float32 dtype in a single conversion
      (the original if/else performed the same float32 cast in both
      branches).
    - 2D input ``(H, W)`` becomes ``(H, W, 1)``.  Inputs with three or more
      dimensions are returned with their shape unchanged — no HWC/CHW
      transposition is performed (the original docstring, copied from
      ``odl_elem_to_torch``, wrongly claimed a transpose to ``(C, H, W)``).

    Returns
    -------
    np.ndarray
        float32 array of shape ``(H, W, 1)`` for 2D input, otherwise the
        input shape.
    """
    arr = np.ascontiguousarray(np.asarray(elem), dtype=np.float32)

    # Normalize shape to (H, W, C) by appending a channel axis to 2D input.
    if arr.ndim == 2:
        arr = arr[..., None]    # (H, W) -> (H, W, 1)

    return arr

Implementing Hybrid Architectures - FBP on top of DRUnet¶

In [ ]:
# Move from Restormer/Denoising back to the home directory so the DPIR repo
# (utils/, models/, model_zoo/) resolves by relative path.
# NOTE(review): hardcoded absolute path — breaks on any other machine.  The
# captured output also shows "[Errno 2] No such file or directory: '# from
# Restormer/ ...'", i.e. %cd once received a stray comment string; keep the
# magic on its own line when re-running.
if os.getcwd() != '/home/hiran':
  %cd '/home/hiran'
/home/hiran/miniconda3/envs/ctrecn3/lib/python3.10/site-packages/IPython/core/magics/osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.
  bkms = self.shell.db.get('bookmarks', {})
[Errno 2] No such file or directory: '# from Restormer/ Denoising to root /home/hiran'
/home/hiran/Restormer/Denoising
Out[ ]:
'/home/hiran/Restormer/Denoising'

Gaussian Denoiser Codebase & Methods, 2021¶

In [ ]:
# git clone the utils, model_zoo and models subdirectories from Kai Zhang's DPIR repo at https://github.com/cszn/DPIR.git (Kai Zhang: Associate Professor, School of Intelligence Science and Technology, Nanjing University).
In [ ]:
# Experimenting with the DRUNet model, with a phantom noise map supplied as the model's 2nd input channel

"""
Kai Zhang (cskaizhang@gmail.com)
github: https://github.com/cszn/DPIR
        https://github.com/cszn/IRCNN
        https://github.com/cszn/KAIR
@article{zhang2020plug,
  title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
  author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
  journal={arXiv preprint},
  year={2020}
}
"""
# IMPORTANT

def main():
    """Run the DRUNet (DPIR) grayscale denoising demo over a test set.

    Adds synthetic AWGN to each test image, denoises it with the pretrained
    DRUNet model (the noise level is passed as an extra input channel), and
    logs per-image and average PSNR/SSIM.

    Requires the DPIR repo layout on the import path: ``model_zoo/``,
    ``testsets/``, the ``utilsImg``/``utils_logger``/``utils_model`` helpers
    and ``models.network_unet``.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------

    noise_level_img = 15                 # set AWGN noise level for noisy image
    noise_level_model = noise_level_img  # set noise level for model
    model_name = 'drunet_gray'           # set denoiser model, 'drunet_gray' | 'drunet_color'
    testset_name = 'set12' # 'bsd68'               # set test set,  'bsd68' | 'cbsd68' | 'set12'
    x8 = False                           # default: False, x8 to boost performance
    show_img = True # False                     # default: False
    border = 0                           # shave border to calculate PSNR and SSIM

    if 'color' in model_name:
        n_channels = 3                   # 3 for color image
    else:
        n_channels = 1                   # 1 for grayscale image

    model_pool = 'model_zoo'             # fixed
    testsets = 'testsets'                # fixed
    results = 'results'                  # fixed
    task_current = 'dn'                  # 'dn' for denoising
    result_name = testset_name + '_' + task_current + '_' + model_name

    model_path = os.path.join(model_pool, model_name+'.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------

    L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
    E_path = os.path.join(results, result_name)   # E_path, for Estimated images
    utilsImg.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
    logger = logging.getLogger(logger_name)

    # ----------------------------------------
    # load model
    # ----------------------------------------

    from models.network_unet import UNetRes as net
    # in_nc = image channels + 1: the extra channel carries the noise-level map
    model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
    model_path = './model_zoo/drunet_gray.pth' # NOTE: overrides the path built above
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False          # inference only: freeze all weights
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    logger.info('Params number: {}'.format(number_parameters))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []

    logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model))
    logger.info(L_path)
    L_paths = utilsImg.get_image_paths(L_path)

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------

        img_name, ext = os.path.splitext(os.path.basename(img))
        # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
        img_H = utilsImg.imread_uint(img, n_channels=n_channels)  # uint8 ground truth
        img_L = utilsImg.uint2single(img_H)  # uint8 -> float32 in [0, 1]

        # Add noise without clipping
        np.random.seed(seed=0)  # for reproducibility
        img_L += np.random.normal(0, noise_level_img/255., img_L.shape) # AWGN: params (mean, sd, shape)

        utilsImg.imshow(utilsImg.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None

        img_L = utilsImg.single2tensor4(img_L)  # -> tensor (batch, channel, height, width)
        # Append a constant noise-level map as an extra input channel (DRUNet convention).
        img_L = torch.cat((img_L, torch.FloatTensor([noise_level_model/255.]).repeat(1, 1, img_L.shape[2], img_L.shape[3])), dim=1)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        # BUGFIX: the original used floor division (`size // 8 == 0`), which
        # is only 0 for images smaller than 8 px, so every normal image was
        # routed through the tiled test_mode path.  Divisibility by 8 needs
        # the modulo operator.
        if not x8 and img_L.size(2) % 8 == 0 and img_L.size(3) % 8 == 0:
            img_E = model(img_L)
        elif not x8 and (img_L.size(2) % 8 != 0 or img_L.size(3) % 8 != 0):
            img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)
        elif x8:
            img_E = utils_model.test_mode(model, img_L, mode=3)

        img_E = utilsImg.tensor2uint(img_E)

        # --------------------------------
        # PSNR and SSIM
        # --------------------------------

        if n_channels == 1:
            img_H = img_H.squeeze()      # drop singleton channel for metric helpers
        psnr = utilsImg.calculate_psnr(img_E, img_H, border=border)
        ssim = utilsImg.calculate_ssim(img_E, img_H, border=border)
        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)
        logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))

        # ------------------------------------
        # save results
        # ------------------------------------

        utilsImg.imsave(img_E, os.path.join(E_path, img_name+ext))

    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
    logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))


# if __name__ == '__main__':

#     main()
Out[ ]:
'/home/hiran'
In [ ]:
# Expanding main() into flat cells for step-by-step experimentation.
# NOTE(review): duplicates the body of main() above; `##` marks lines relevant
# to the planned pipeline diversion.

# ----------------------------------------
# Preparation
# ----------------------------------------

noise_level_img = 15                 # set AWGN noise level for noisy image ##
noise_level_model = noise_level_img  # set noise level for model ##
model_name = 'drunet_gray'           # set denoiser model, 'drunet_gray' | 'drunet_color'
testset_name = 'set12' # 'bsd68'               # set test set,  'bsd68' | 'cbsd68' | 'set12'
x8 = False                           # default: False, x8 to boost performance
show_img = True # False                     # default: False
border = 0                           # shave border to calculate PSNR and SSIM

if 'color' in model_name:
    n_channels = 3                   # 3 for color image
else:
    n_channels = 1                   # 1 for grayscale image

model_pool = 'model_zoo'             # fixed
testsets = 'testsets'                # fixed ##
results = 'results'                  # fixed
task_current = 'dn'                  # 'dn' for denoising
result_name = testset_name + '_' + task_current + '_' + model_name ##

model_path = os.path.join(model_pool, model_name+'.pth')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

# ----------------------------------------
# L_path, E_path, H_path
# ----------------------------------------

L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
E_path = os.path.join(results, result_name)   # E_path, for Estimated images
utilsImg.mkdir(E_path) # util

logger_name = result_name
utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
logger = logging.getLogger(logger_name)

# ----------------------------------------
# load model
# ----------------------------------------

from models.network_unet import UNetRes as net
model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
model_path = './model_zoo/drunet_gray.pth' # param path (overrides the path built above)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
    v.requires_grad = False  # inference only: freeze all weights
model = model.to(device)
logger.info('Model path: {:s}'.format(model_path))
number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
logger.info('Params number: {}'.format(number_parameters))

test_results = OrderedDict()
test_results['psnr'] = []
test_results['ssim'] = []

logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model))
logger.info(L_path)
L_paths = utilsImg.get_image_paths(L_path) # util
25-09-17 13:46:46.486 : model_name:drunet_gray, model sigma:15, image sigma:15
25-09-17 13:46:46.489 : testsets/set12
In [ ]:
# Iterate over the test dataset: add AWGN, denoise with DRUNet, log metrics.

for idx, img in enumerate(L_paths):

    # ------------------------------------
    # (1) img_L
    # ------------------------------------

    img_name, ext = os.path.splitext(os.path.basename(img))
    # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
    img_H = utilsImg.imread_uint(img, n_channels=n_channels)  # uint8 ground truth
    img_L = utilsImg.uint2single(img_H)  # uint8 -> float32 in [0, 1]

    # Add noise without clipping
    np.random.seed(seed=0)  # for reproducibility
    img_L += np.random.normal(0, noise_level_img/255., img_L.shape) # AWGN: params (mean, sd, shape)

    utilsImg.imshow(utilsImg.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None

    img_L = utilsImg.single2tensor4(img_L)  # -> tensor (batch, channel, height, width)
    # Append a constant noise-level map as an extra input channel (DRUNet convention).
    img_L = torch.cat((img_L, torch.FloatTensor([noise_level_model/255.]).repeat(1, 1, img_L.shape[2], img_L.shape[3])), dim=1)
    img_L = img_L.to(device)

    # ------------------------------------
    # (2) img_E
    # ------------------------------------

    # BUGFIX: the original used floor division (`size // 8 == 0`), which is
    # only 0 for images smaller than 8 px; divisibility by 8 needs modulo.
    if not x8 and img_L.size(2) % 8 == 0 and img_L.size(3) % 8 == 0:
        img_E = model(img_L)
    elif not x8 and (img_L.size(2) % 8 != 0 or img_L.size(3) % 8 != 0):
        img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)
    elif x8:
        img_E = utils_model.test_mode(model, img_L, mode=3)

    img_E = utilsImg.tensor2uint(img_E)

    # --------------------------------
    # PSNR and SSIM
    # --------------------------------

    if n_channels == 1:
        img_H = img_H.squeeze()  # drop singleton channel for metric helpers
    psnr = utilsImg.calculate_psnr(img_E, img_H, border=border)
    ssim = utilsImg.calculate_ssim(img_E, img_H, border=border)
    test_results['psnr'].append(psnr)
    test_results['ssim'].append(ssim)
    logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))

    # ------------------------------------
    # save results
    # ------------------------------------

    utilsImg.imsave(img_E, os.path.join(E_path, img_name+ext))

ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))
Output hidden; open in https://colab.research.google.com to view.

Modifying the DRUNet pipeline to fit our purpose and to run with 2025 library versions¶

In [ ]:
# Updating DRUNet from its 2021 library versions, fixing bugs, and testing its fitness for our purpose
In [ ]:
# Working on the blind variant of the model, where the noise dimension of the
# input tensor is calculated rather than fixed.
# NOTE(review): duplicates large parts of DPIR's real-application script,
# kept flat here for cell-by-cell experimentation.

# def main():
# def DruNet_diversion(img_in_single_prec):

"""
# ----------------------------------------------------------------------------------
# In real applications, we should set proper
# - "noise_level_img": from [3, 25], set 3 for clean image, try 15 for very noisy LR images
# - "k" (or "kernel_width"): blur kernel is very important!!!  kernel_width from [0.6, 3.0]
# to get the best performance.
# ----------------------------------------------------------------------------------
"""
##############################################################################

testset_name = 'set3c'               # set test set,  'set5' | 'srbsd68'
noise_level_img = 3                  # set noise level of image, from [3, 25], set 3 for clean image
model_name = 'drunet_gray' # 'drunet_color' # 'ircnn_color'         # set denoiser, | 'drunet_color' | 'ircnn_gray' | 'drunet_gray' | 'ircnn_color'
sf = 1 # 2                               # set scale factor, 1, 2, 3, 4
iter_num = 24                        # set number of iterations, default: 24 for SISR

# --------------------------------
# set blur kernel
# --------------------------------
kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2] # Gaussian kernel widths for x1, x2, x3, x4
noise_level_model = noise_level_img/255.  # noise level of model
kernel_width = kernel_width_default_x1234[sf-1]

"""
# set your own kernel width !!!!!!!!!!
"""
# kernel_width = 1.0


k = utils_deblur.fspecial('gaussian', 25, kernel_width)
k = sr.shift_pixel(k, sf)  # sr.shift_pixel modified # shift the kernel
k /= np.sum(k)  # normalize kernel to unit sum

##############################################################################

show_img = True # False
utilsImg.surf(k) if show_img else None
x8 = False # True                            # default: False, x8 to boost performance
modelSigma1 = 49                     # set sigma_1, default: 49
modelSigma2 = max(sf, noise_level_model*255.)
classical_degradation = True         # set classical degradation or bicubic degradation

task_current = 'dn' # 'dn' for denoising # 'sr' # 'sr' for super-resolution
n_channels = 1 if 'gray' in model_name else 3  # fixed
model_zoo = 'model_zoo'              # fixed
testsets = 'testsets'                # fixed
results = 'results'                  # fixed
result_name = testset_name + '_realapplications_' + task_current + '_' + model_name
model_path = os.path.join(model_zoo, model_name+'.pth')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

# ----------------------------------------
# L_path, E_path, H_path
# ----------------------------------------
L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
E_path = os.path.join(results, result_name)   # E_path, for Estimated images
utilsImg.mkdir(E_path)

logger_name = result_name
utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
logger = logging.getLogger(logger_name)

# ----------------------------------------
# load model
# ----------------------------------------
if 'drunet' in model_name:
    from models.network_unet import UNetRes as net
    model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
elif 'ircnn' in model_name:
    from models.network_dncnn import IRCNN as net
    model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
    model25 = torch.load(model_path)
    former_idx = 0

logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
logger.info('Model path: {:s}'.format(model_path))
logger.info(L_path)
L_paths = utilsImg.get_image_paths(L_path)
#
for idx, img in enumerate(L_paths):

    # --------------------------------
    # (1) get img_L
    # --------------------------------
    logger.info('Model path: {:s} Image: {:s}'.format(model_path, img))
    img_name, ext = os.path.splitext(os.path.basename(img))
    img_L = utilsImg.imread_uint(img, n_channels=n_channels)
    img_L = utilsImg.uint2single(img_L)

    print('img_L.shape: {}'.format(img_L.shape))
    print(img_L[0,0,0])
    img_L = utilsImg.modcrop(img_L, 8)  # modcrop
    print(img_L[0,0,0])
    print('img_L.shape: {}'.format(img_L.shape))
    # match indent of DruNet diversion code block to re-build complete DruNet pipeline
    # rm def, imgL overlap, imgsave, return

    # diverting DruNet to lodopab functionality

    # img_L = utilsImg.modcrop(img_in_single_prec, 8)  # modcrop # uncomment for divert function
    print('img_L.shape: {}'.format(img_L.shape))


    # --------------------------------
    # (2) get rhos and sigmas
    # --------------------------------
    rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1)
    rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)

    # --------------------------------
    # (3) initialize x, and pre-calculation
    # --------------------------------
    x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC) # scaling up

    if np.ndim(x)==2:
        x = x[..., None]

    if classical_degradation:
        x = sr. shift_pixel(x, sf) # sr.shift_pixel modified due to libraries deprecated
    x = utilsImg.single2tensor4(x).to(device)

    img_L_tensor, k_tensor = utilsImg.single2tensor4(img_L), utilsImg.single2tensor4(np.expand_dims(k, 2)) # note that this is the default size img and k is utils_deblur.fspecial('gaussian', 25, kernel_width)
    [k_tensor, img_L_tensor] = utilsImg.todevice([k_tensor, img_L_tensor], device)
    FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

    # --------------------------------
    # (4) main iterations
    # --------------------------------
    for i in range(iter_num):

        print('Iter: {} / {}'.format(i, iter_num))

        # --------------------------------
        # step 1, FFT
        # --------------------------------
        tau = rhos[i].float().repeat(1, 1, 1, 1)
        x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)

        if 'ircnn' in model_name:
            current_idx = np.int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)

            if current_idx != former_idx:
                model.load_state_dict(model25[str(current_idx)], strict=True)
                model.eval()
                for _, v in model.named_parameters():
                    v.requires_grad = False
                model = model.to(device)
            former_idx = current_idx

        # --------------------------------
        # step 2, denoiser
        # --------------------------------
        if x8:
            x = utilsImg.augment_img_tensor4(x, i % 8) # augmenting 8 ways

        if 'drunet' in model_name:
            x = torch.cat((x, sigmas[i].repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
            x = utils_model.test_mode(model, x, mode=2, refield=64, min_size=256, modulo=16)
        elif 'ircnn' in model_name:
            x = model(x)

        if x8:
            if i % 8 == 3 or i % 8 == 5:
                x = utilsImg.augment_img_tensor4(x, 8 - i % 8)
            else:
                x = utilsImg.augment_img_tensor4(x, i % 8)

    # --------------------------------
    # (3) img_E
    # --------------------------------
    img_E = utilsImg.tensor2uint(x)
    utilsImg.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'.png'))

    # Diverting the Drunet model output to lodopab translation layer
    # divert = utilsImg.tensor2single(x)
    # return divert

# if __name__ == '__main__':

#     main()
<Figure size 640x480 with 0 Axes>
LogHandlers setup!
25-09-17 14:18:42.594 : model_name:drunet_gray, image sigma:3.000, model sigma:0.012
25-09-17 14:18:42.597 : Model path: model_zoo/drunet_gray.pth
25-09-17 14:18:42.597 : testsets/set3c
25-09-17 14:18:42.598 : Model path: model_zoo/drunet_gray.pth Image: testsets/set3c/butterfly.png
img_L.shape: (256, 256, 1)
0.12549
0.12549
img_L.shape: (256, 256, 1)
img_L.shape: (256, 256, 1)
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
25-09-17 14:18:43.304 : Model path: model_zoo/drunet_gray.pth Image: testsets/set3c/leaves.png
img_L.shape: (256, 256, 1)
0.909804
0.909804
img_L.shape: (256, 256, 1)
img_L.shape: (256, 256, 1)
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
25-09-17 14:18:43.833 : Model path: model_zoo/drunet_gray.pth Image: testsets/set3c/starfish.png
Iter: 23 / 24
img_L.shape: (256, 256, 1)
0.466667
0.466667
img_L.shape: (256, 256, 1)
img_L.shape: (256, 256, 1)
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24

Model Definition and Compilation¶

In [ ]:
def DruNet_diversion(img_in_single_prec):
  """Run the DPIR plug-and-play pipeline (DRUNet prior) on one in-memory image.

  Adaptation of the DPIR demo script above: instead of reading images from
  disk and saving PNGs, it takes a single-precision image and returns the
  restored result, so the denoiser can sit on top of an FBP CT
  reconstruction (LoDoPaB).

  Parameters
  ----------
  img_in_single_prec : ndarray
      Single-precision image, expected shape (H, W, 1) for the grayscale
      model. Height/width are cropped to multiples of 8 (modcrop), so the
      output may be slightly smaller than the input.

  Returns
  -------
  ndarray
      Restored image, shape (H, W) — utilsImg.tensor2single drops the
      channel axis.

  Notes
  -----
  In real applications set proper values for:
  - "noise_level_img": from [3, 25]; 3 for clean images, ~15 for very noisy ones
  - "k" (or "kernel_width"): blur kernel width from [0.6, 3.0] — very important
  """
  ##############################################################################

  # Experiment configuration (kept identical to the script version above).
  testset_name = 'set3c'               # set test set,  'set5' | 'srbsd68'
  noise_level_img = 12 # 8 psnr 23 # 3              # set noise level of image, from [3, 25], set 3 for clean image
  model_name = 'drunet_gray' # 'drunet_color' # 'ircnn_gray' # 'ircnn_color'   # set denoiser
  sf = 1 # 2                               # set scale factor, 1, 2, 3, 4
  iter_num = 24                        # set number of iterations, default: 24 for SISR

  # --------------------------------
  # set blur kernel
  # --------------------------------
  kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2] # Gaussian kernel widths for x1, x2, x3, x4
  noise_level_model = noise_level_img/255.  # noise level of model
  kernel_width = kernel_width_default_x1234[sf-1]
  # kernel_width = 1.0  # set your own kernel width if needed

  k = utils_deblur.fspecial('gaussian', 25, kernel_width)
  k = sr.shift_pixel(k, sf)  # sr.shift_pixel modified # shift the kernel
  k /= np.sum(k)             # normalize kernel to unit mass

  ##############################################################################

  show_img = False # True
  if show_img:
      utilsImg.surf(k)
  x8 = False # True                            # default: False, x8 to boost performance
  modelSigma1 = 49                     # set sigma_1, default: 49
  modelSigma2 = max(sf, noise_level_model*255.)
  classical_degradation = True         # set classical degradation or bicubic degradation

  task_current = 'dn' # 'dn' for denoising # 'sr' for super-resolution
  n_channels = 1 if 'gray' in model_name else 3  # fixed
  model_zoo = 'model_zoo'              # fixed
  testsets = 'testsets'                # fixed
  results = 'results'                  # fixed
  result_name = testset_name + '_realapplications_' + task_current + '_' + model_name
  model_path = os.path.join(model_zoo, model_name+'.pth')
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  torch.cuda.empty_cache()

  # ----------------------------------------
  # L_path, E_path
  # ----------------------------------------
  L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
  E_path = os.path.join(results, result_name)   # E_path, for Estimated images
  utilsImg.mkdir(E_path)

  logger_name = result_name
  utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
  logger = logging.getLogger(logger_name)

  # ----------------------------------------
  # load model
  # ----------------------------------------
  if 'drunet' in model_name:
      from models.network_unet import UNetRes as net
      model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
      model.load_state_dict(torch.load(model_path), strict=True)
      model.eval()
      for _, v in model.named_parameters():
          v.requires_grad = False
      model = model.to(device)
  elif 'ircnn' in model_name:
      from models.network_dncnn import IRCNN as net
      model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
      model25 = torch.load(model_path)
      former_idx = 0

  logger.info(L_path)
  L_paths = utilsImg.get_image_paths(L_path)  # NOTE(review): unused in the diversion path — candidate for removal

  # Diverting DruNet to lodopab functionality: the in-memory input replaces
  # the files the original script read from L_path.
  img_L = utilsImg.modcrop(img_in_single_prec, 8)  # crop H, W to multiples of 8

  # --------------------------------
  # (2) get rhos and sigmas
  # --------------------------------
  rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1)
  rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)

  # --------------------------------
  # (3) initialize x, and pre-calculation
  # --------------------------------
  x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC) # scaling up

  if np.ndim(x)==2:
      x = x[..., None]

  if classical_degradation:
      x = sr.shift_pixel(x, sf) # sr.shift_pixel modified due to libraries deprecated
  x = utilsImg.single2tensor4(x).to(device)

  # default-size image; k is utils_deblur.fspecial('gaussian', 25, kernel_width)
  img_L_tensor, k_tensor = utilsImg.single2tensor4(img_L), utilsImg.single2tensor4(np.expand_dims(k, 2))
  [k_tensor, img_L_tensor] = utilsImg.todevice([k_tensor, img_L_tensor], device)
  FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

  # --------------------------------
  # (4) main iterations (HQS: FFT data step + denoiser prior step)
  # --------------------------------
  for i in range(iter_num):

      print('Iter: {} / {}'.format(i, iter_num))

      # step 1, FFT data sub-problem
      tau = rhos[i].float().repeat(1, 1, 1, 1)
      x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)

      # step 2, denoiser prior
      if x8:
          x = utilsImg.augment_img_tensor4(x, i % 8) # augmenting 8 ways

      if 'drunet' in model_name:
          # DRUNet takes the noise level as an extra input channel
          x = torch.cat((x, sigmas[i].repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
          x = utils_model.test_mode(model, x, mode=2, refield=64, min_size=256, modulo=16)
      elif 'ircnn' in model_name:
          x = model(x)

      if x8:
          # undo the augmentation applied before the denoiser
          if i % 8 == 3 or i % 8 == 5:
              x = utilsImg.augment_img_tensor4(x, 8 - i % 8)
          else:
              x = utilsImg.augment_img_tensor4(x, i % 8)

  # Diverting the DRUNet model output to the lodopab translation layer.
  divert = utilsImg.tensor2single(x) # shape (H, W)
  return divert

Apply Gaussian Denoising model on top of the FBP model¶

In [ ]:
"""
Implement DruNet on 'lodopab'.
"""
LOG_DIR2 = './logs/lodopab_drunet'
SAVE_BEST_LEARNED_PARAMS_PATH2 = './params/lodopab_drunet'
In [ ]:
# Load DRUNet (grayscale): in_nc=2 is 1 image channel + 1 noise-level channel.

from models.network_unet import UNetRes as net
model2 = net(in_nc=2, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
model2_path = './model_zoo/drunet_gray.pth' # param path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a CUDA-saved checkpoint load on CPU-only hosts too
model2.load_state_dict(torch.load(model2_path, map_location=device), strict=True)
model2.eval()
# Inference only: freeze all parameters.
# FIX: loop variable renamed from `k` to `_` — the original `for k, v in ...`
# rebound the module-level blur kernel `k` used elsewhere in the notebook
# (hidden-state hazard on re-run).
for _, v in model2.named_parameters():
    v.requires_grad = False
torch.cuda.empty_cache()
model2 = model2.to(device)
In [ ]:
#%% evaluate
# del recos2 # if recos2 exists
# del psnrs2 # if recos2 exists

# Evaluate FBP + DruNet_diversion on the 10-sample test set:
# reconstruct each observation with FBP, denoise with the DRUNet
# plug-and-play pipeline, and score PSNR against the ground truth.
recos2 = []
psnrs2 = []
for obs, gt in test_data_10:
    reco = reconstructor_lodopab.reconstruct(obs)

    # applying the transition layer and Denoising model on top
    # print(reco[0,0])
    # print(reco.shape)
    reco = odl_to_single(reco) # shape change (H, W) -> (H, W, 1)
    reco2 = DruNet_diversion(reco)  # returns shape (H, W) — see DruNet_diversion return
    # print(reco2[0,0])
    # print(reco2.shape)
    gt = utilsImg.modcrop(gt, 8) # match shape of reco2 modcropped in DruNet

    # resume default pipeline
    recos2.append(reco2)
    psnrs2.append(PSNR(reco2, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs2)))

# NOTE(review): the plots use the un-cropped ground truth
# (test_data_10.ground_truth[i]) while PSNR above used the modcropped `gt`.
for i in range(3):
    _, ax = plot_images([recos2[i], test_data_10.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs2[i]))
    ax[0].set_title('Drunet_Reconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))
LogHandlers exists!
25-09-17 14:46:14.099 : testsets/set3c
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
25-09-17 14:46:14.557 : testsets/set3c
LogHandlers exists!
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
25-09-17 14:46:15.908 : testsets/set3c
LogHandlers exists!
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
25-09-17 14:46:17.264 : testsets/set3c
LogHandlers exists!
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
LogHandlers exists!
25-09-17 14:46:18.593 : testsets/set3c
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
LogHandlers exists!
25-09-17 14:46:19.945 : testsets/set3c
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
LogHandlers exists!
25-09-17 14:46:21.309 : testsets/set3c
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
LogHandlers exists!
25-09-17 14:46:22.645 : testsets/set3c
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
25-09-17 14:46:23.985 : testsets/set3c
LogHandlers exists!
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
LogHandlers exists!
25-09-17 14:46:25.327 : testsets/set3c
Iter: 0 / 24
Iter: 1 / 24
Iter: 2 / 24
Iter: 3 / 24
Iter: 4 / 24
Iter: 5 / 24
Iter: 6 / 24
Iter: 7 / 24
Iter: 8 / 24
Iter: 9 / 24
Iter: 10 / 24
Iter: 11 / 24
Iter: 12 / 24
Iter: 13 / 24
Iter: 14 / 24
Iter: 15 / 24
Iter: 16 / 24
Iter: 17 / 24
Iter: 18 / 24
Iter: 19 / 24
Iter: 20 / 24
Iter: 21 / 24
Iter: 22 / 24
Iter: 23 / 24
mean psnr: 26.474400
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Saving as a Class - the first Hybrid Model: CNN Gaussian Denoising NN on top of FBP Model¶

In [22]:
# Ensure the working directory is the home directory before importing local modules.
# NOTE(review): hardcoded absolute local path — not portable across machines.
if os.getcwd() != '/home/hiran':
  %cd '/home/hiran'
/home/hiran
/home/hiran/miniconda3/envs/ctrecn3/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
In [23]:
from models.network_unet import UNetRes

class hybrid_model_UNetRes(UNetRes):
    """Hybrid CT reconstructor: FBP followed by the DPIR/DRUNet plug-and-play denoiser.

    Subclasses UNetRes (presumably to reuse its nn.Module machinery — confirm);
    the actual denoiser network is loaded separately into ``self.model``.
    ``reconstruct(obs)`` mirrors the reconstructor interface used elsewhere
    in this notebook.
    """

    def __init__(self, weights, name='not_defined', fbp_model = reconstructor_lodopab, sigma = 8):
        # weights: path to the DRUNet checkpoint (.pth)
        # name: human-readable model name
        # fbp_model: FBP reconstructor producing the initial estimate
        #   (NOTE: default binds module-level reconstructor_lodopab at class-definition time)
        # sigma: image noise level, from [3, 25]

        super(hybrid_model_UNetRes, self).__init__()
        self.name = name
        self.weights = weights
        self.fbp_model = fbp_model
        self.sigma = sigma
        # NOTE(review): model_name/sf/iter_num/task_current below are locals that
        # are unused after __init__; reconstruct() re-declares its own copies.
        model_name = 'drunet_gray'
        sf = 1 # 2                               # set scale factor, 1, 2, 3, 4
        iter_num = 24                        # set number of iterations, default: 24 for SISR

        task_current = 'dn' # 'dn' for denoising # 'sr' # 'sr' for super-resolution
        n_channels = 1 if 'gray' in model_name else 3  # fixed
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        torch.cuda.empty_cache()

        # Load the DRUNet denoiser (in_nc = image channels + 1 noise-level channel)
        # and freeze it for inference.
        from models.network_unet import UNetRes as net
        model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
        model.load_state_dict(torch.load(self.weights), strict=True)
        model.eval()
        for _, v in model.named_parameters():
            v.requires_grad = False
        self.model = model.to(device)


    def reconstruct(self, obs): # define _reconstruct if want to default to UNetRes reconstruct attribute
      """FBP-reconstruct `obs`, then run iter_num HQS iterations with the DRUNet prior.

      Returns a uniform_discr_element holding the restored (h, w) image,
      cropped back to the pre-padding size.
      """
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      self.obs = obs
      reco = self.fbp_model.reconstruct(obs)
      reco = odl_to_single(reco) # shape change (H, W) -> (H, W, 1)

      ####

      reco = torch.from_numpy(reco[:,:,0])
      # Pad the input if not_multiple_of 8 # pad height and width
      img_multiple_of = 8
      h,w = reco.shape[0], reco.shape[1]
      H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of
      padh = H-h if h%img_multiple_of!=0 else 0
      padw = W-w if w%img_multiple_of!=0 else 0
      reco = F.pad(reco, (0,padw,0,padh), 'constant') # padding from last dim, left-right, top-btm

      ###############

      noise_level_img = self.sigma # 8 psnr 23 # 3              # set noise level of image, from [3, 25], set 3 for clean image
      sf = 1 # 2                               # set scale factor, 1, 2, 3, 4
      iter_num = 24                        # set number of iterations, default: 24 for SISR

      # --------------------------------
      # set blur kernel
      # --------------------------------
      kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2] # Gaussian kernel widths for x1, x2, x3, x4
      noise_level_model = noise_level_img/255.  # noise level of model
      kernel_width = kernel_width_default_x1234[sf-1]

      # kernel_width = 1.0

      k = utils_deblur.fspecial('gaussian', 25, kernel_width)
      k = sr.shift_pixel(k, sf)  # sr.shift_pixel modified # shift the kernel
      k /= np.sum(k)  # normalize kernel to unit mass

      show_img = False # True
      utilsImg.surf(k) if show_img else None
      x8 = False # True                            # default: False, x8 to boost performance
      modelSigma1 = 49                     # set sigma_1, default: 49
      modelSigma2 = max(sf, noise_level_model*255.)
      classical_degradation = True         # set classical degradation or bicubic degradation

      # --------------------------------
      # (2) get rhos and sigmas
      # --------------------------------
      rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1)
      rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)

      # --------------------------------
      # (3) initialize x, and pre-calculation
      # --------------------------------
      img_L = reco.numpy()
      x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC) # scaling up

      if np.ndim(x)==2:
          x = x[..., None]

      if classical_degradation:
          x = sr. shift_pixel(x, sf) # sr.shift_pixel modified due to libraries deprecated
      x = utilsImg.single2tensor4(x).to(device)

      img_L_tensor, k_tensor = utilsImg.single2tensor4(img_L[...,None]), utilsImg.single2tensor4(np.expand_dims(k, 2)) # note that this is the default size img and k is utils_deblur.fspecial('gaussian', 25, kernel_width)
      [k_tensor, img_L_tensor] = utilsImg.todevice([k_tensor, img_L_tensor], device)
      FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

      # --------------------------------
      # (4) main iterations (HQS: FFT data step + denoiser prior step)
      # --------------------------------
      for i in range(iter_num):

          # print('Iter: {} / {}'.format(i, iter_num))

          # --------------------------------
          # step 1, FFT data sub-problem
          # --------------------------------
          tau = rhos[i].float().repeat(1, 1, 1, 1)
          x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)

          # if 'ircnn' in model_name:
          #     current_idx = np.int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)

          #     if current_idx != former_idx:
          #         model.load_state_dict(model25[str(current_idx)], strict=True)
          #         model.eval()
          #         for _, v in model.named_parameters():
          #             v.requires_grad = False
          #         model = model.to(device)
          #     former_idx = current_idx

          # --------------------------------
          # step 2, denoiser prior
          # --------------------------------
          if x8:
              x = utilsImg.augment_img_tensor4(x, i % 8) # augmenting 8 ways

          # if 'drunet' in model_name:
          # Correct the concatenation to add sigmas as a single channel
          sigma_tensor = sigmas[i].repeat(1, 1, x.shape[2], x.shape[3]) # shape 1, 1, H, W
          x = torch.cat((x, sigma_tensor), dim=1) # concatenate along channel dimension
          x = utils_model.test_mode(self.model, x, mode=2, refield=64, min_size=256, modulo=16)

      divert = torch.clamp(x, 0, 1) # set floor ceiling for pix vals
      divert = utilsImg.tensor2single(divert) # shape H, W
      divert = divert[:h,:w] # crop back to the pre-padding size
      # divert = divert[..., None] # shape H, W, 1
      return uniform_discr_element(divert)

# NOTE(review): hardcoded absolute checkpoint path — not portable across machines.
weightsU8 = '/home/hiran/model_zoo/drunet_gray.pth'
nameU8 = 'Hybrid UNet Residual model sigma 8'

# Instantiate the hybrid reconstructor with noise level sigma = 8.
hybrid_model_u8 = hybrid_model_UNetRes(weightsU8, nameU8, sigma = 8)
In [ ]:
# testing pipeline

# Plain reassignment releases any previous results; the former
# `if 'recos2' in locals(): del recos2; del psnrs2` guard was redundant
# because rebinding the names drops the old lists anyway.
recos2 = []
psnrs2 = []

for obs, gt in test_data_10:
    torch.cuda.ipc_collect() # collects unnecessary inter-process comm.s and free VRAM
    torch.cuda.empty_cache()

    # the hybrid class wraps FBP + DRUNet, so reconstruct straight from the observation
    reco2 = hybrid_model_u8.reconstruct(obs)

    # resume default pipeline
    recos2.append(reco2)
    psnrs2.append(PSNR(reco2, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs2)))

# Show the first three reconstructions next to their ground truths.
for i in range(3):
    _, ax = plot_images([recos2[i], test_data_10.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs2[i]))
    ax[0].set_title('Drunet_Reconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))

Experimenting a Hybrid NN - Transformer/ Real Image Denoising on top of FBP Model¶

Restormer SOTA (CVPR-2022) Real Image Denoising Neural Network¶

In [ ]:
# original work by Syed Waqas Zamir, Aditya Arora
# https://github.com/swz30/Restormer.git

1. Setup¶

  • First, in the Runtime menu -> Change runtime type, make sure to have Hardware Accelerator = GPU
  • Clone repo and install dependencies.
In [ ]:
# Move into the Restormer denoising directory; fail fast if the repo layout
# is not as expected. NOTE(review): hardcoded local path — not portable;
# ImportError is an odd choice here (RuntimeError would describe it better).
%cd Restormer/Denoising
if os.getcwd() != '/home/hiran/Restormer/Denoising':
  raise ImportError

# Clone Restormer
# !git clone https://github.com/swz30/Restormer.git
/home/hiran/Restormer/Denoising
/home/hiran/miniconda3/envs/ctrecn3/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]

2. Define Task and Download Pre-trained Models¶

Uncomment the task you would like to perform

In [ ]:
task = 'Real_Denoising'
# task = 'Single_Image_Defocus_Deblurring'
# task = 'Motion_Deblurring'
# task = 'Deraining'

# Download the pre-trained models
# if task is 'Real_Denoising':
#   !wget https://github.com/swz30/Restormer/releases/download/v1.0/real_denoising.pth -P Denoising/pretrained_models
if task is 'Single_Image_Defocus_Deblurring':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/single_image_defocus_deblurring.pth -P Defocus_Deblurring/pretrained_models
if task is 'Motion_Deblurring':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth -P Motion_Deblurring/pretrained_models
if task is 'Deraining':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth -P Deraining/pretrained_models
<>:9: SyntaxWarning: "is" with a literal. Did you mean "=="?
<>:11: SyntaxWarning: "is" with a literal. Did you mean "=="?
<>:13: SyntaxWarning: "is" with a literal. Did you mean "=="?
<>:9: SyntaxWarning: "is" with a literal. Did you mean "=="?
<>:11: SyntaxWarning: "is" with a literal. Did you mean "=="?
<>:13: SyntaxWarning: "is" with a literal. Did you mean "=="?
/tmp/ipykernel_383/4256543054.py:9: SyntaxWarning: "is" with a literal. Did you mean "=="?
  if task is 'Single_Image_Defocus_Deblurring':
/tmp/ipykernel_383/4256543054.py:11: SyntaxWarning: "is" with a literal. Did you mean "=="?
  if task is 'Motion_Deblurring':
/tmp/ipykernel_383/4256543054.py:13: SyntaxWarning: "is" with a literal. Did you mean "=="?
  if task is 'Deraining':

3. Upload Images¶

Either download the sample images or upload your own images

In [ ]:
# # from google.colab import files

# Download sample images
# !rm -r demo/*
# !wget https://github.com/swz30/Restormer/releases/download/v1.0/sample_images.zip -P demo
# shutil.unpack_archive('demo/sample_images.zip', 'demo/')
# os.remove('demo/sample_images.zip')

# OR Uncomment the following block if you would like to upload your own images.

# !rm -r demo/*
# input_dir = 'demo/sample_images/'+task+'/degraded'
# os.makedirs(input_dir, exist_ok=True)
# uploaded = files.upload()
# for filename in uploaded.keys():
#   input_path = os.path.join(input_dir, filename)
#   shutil.move(filename, input_path)
--2025-09-07 09:30:08--  https://github.com/swz30/Restormer/releases/download/v1.0/sample_images.zip
Resolving github.com (github.com)... 20.26.156.215
Connecting to github.com (github.com)|20.26.156.215|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://release-assets.githubusercontent.com/github-production-release-asset/418793252/ac90e9f3-ee13-4c5f-b4ee-5e057f2b7c43?sp=r&sv=2018-11-09&sr=b&spr=https&se=2025-09-07T09%3A09%3A24Z&rscd=attachment%3B+filename%3Dsample_images.zip&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2025-09-07T08%3A09%3A22Z&ske=2025-09-07T09%3A09%3A24Z&sks=b&skv=2018-11-09&sig=7Jte42ikcR08G82jTxB%2F9R%2FYJf9nHYHVpaP2rDQTQWQ%3D&jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmVsZWFzZS1hc3NldHMuZ2l0aHVidXNlcmNvbnRlbnQuY29tIiwia2V5Ijoia2V5MSIsImV4cCI6MTc1NzIzNDEwNiwibmJmIjoxNzU3MjMzODA2LCJwYXRoIjoicmVsZWFzZWFzc2V0cHJvZHVjdGlvbi5ibG9iLmNvcmUud2luZG93cy5uZXQifQ.HSSLQ91EPGCSMi8nsvJKdfOdTPA1Hp1fO1LrGrKOcxM&response-content-disposition=attachment%3B%20filename%3Dsample_images.zip&response-content-type=application%2Foctet-stream [following]
--2025-09-07 09:30:08--  https://release-assets.githubusercontent.com/github-production-release-asset/418793252/ac90e9f3-ee13-4c5f-b4ee-5e057f2b7c43?sp=r&sv=2018-11-09&sr=b&spr=https&se=2025-09-07T09%3A09%3A24Z&rscd=attachment%3B+filename%3Dsample_images.zip&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2025-09-07T08%3A09%3A22Z&ske=2025-09-07T09%3A09%3A24Z&sks=b&skv=2018-11-09&sig=7Jte42ikcR08G82jTxB%2F9R%2FYJf9nHYHVpaP2rDQTQWQ%3D&jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmVsZWFzZS1hc3NldHMuZ2l0aHVidXNlcmNvbnRlbnQuY29tIiwia2V5Ijoia2V5MSIsImV4cCI6MTc1NzIzNDEwNiwibmJmIjoxNzU3MjMzODA2LCJwYXRoIjoicmVsZWFzZWFzc2V0cHJvZHVjdGlvbi5ibG9iLmNvcmUud2luZG93cy5uZXQifQ.HSSLQ91EPGCSMi8nsvJKdfOdTPA1Hp1fO1LrGrKOcxM&response-content-disposition=attachment%3B%20filename%3Dsample_images.zip&response-content-type=application%2Foctet-stream
Resolving release-assets.githubusercontent.com (release-assets.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to release-assets.githubusercontent.com (release-assets.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4928696 (4.7M) [application/octet-stream]
Saving to: ‘demo/sample_images.zip’

sample_images.zip   100%[===================>]   4.70M  17.6MB/s    in 0.3s    

2025-09-07 09:30:09 (17.6 MB/s) - ‘demo/sample_images.zip’ saved [4928696/4928696]

4. Prepare Model and Load Checkpoint¶

In [ ]:
def get_weights_and_parameters(task, parameters):
    """Resolve the pretrained checkpoint path for `task` and patch `parameters`.

    Args:
        task: one of 'Motion_Deblurring', 'Single_Image_Defocus_Deblurring',
            'Deraining', 'Real_Denoising'.
        parameters: dict of Restormer constructor kwargs; mutated in place for
            'Real_Denoising' (switches to BiasFree layer norm) and returned.

    Returns:
        (weights_path, parameters) tuple.

    Raises:
        ValueError: for an unrecognized task. (Previously an unknown task fell
            through all branches and raised UnboundLocalError on `weights`.)
    """
    if task == 'Motion_Deblurring':
        weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
    elif task == 'Single_Image_Defocus_Deblurring':
        weights = os.path.join('Defocus_Deblurring', 'pretrained_models', 'single_image_defocus_deblurring.pth')
    elif task == 'Deraining':
        weights = os.path.join('Deraining', 'pretrained_models', 'deraining.pth')
    elif task == 'Real_Denoising': # this attempt
        weights = os.path.join('Denoising', 'pretrained_models', 'real_denoising.pth')
        parameters['LayerNorm_type'] = 'BiasFree'
    else:
        raise ValueError(f"Unknown task: {task!r}")
    return weights, parameters


# Get model weights and parameters
# Default Restormer-base hyper-parameters; get_weights_and_parameters may
# override entries per task (e.g. BiasFree layer norm for Real_Denoising).
parameters = {'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}
# NOTE(review): `task` must be defined by an earlier cell - hidden cross-cell state.
weights, parameters = get_weights_and_parameters(task, parameters)

# run_path executes the architecture file and returns its globals, so the
# Restormer class can be constructed without a package-level import.
load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)
model.cuda()

# NOTE(review): torch.load unpickles arbitrary objects - only load trusted
# checkpoints (or pass weights_only=True on newer torch versions).
checkpoint = torch.load(weights)
model.load_state_dict(checkpoint['params'])
model.eval() # disable dropout, batch norm by changing the mode of the model

5. Inference - stock photos¶

In [ ]:
# Inference on the stock demo photos: run every degraded image through the
# pretrained Restormer and write the restored result to the sibling folder.
input_dir = f'demo/sample_images/{task}/degraded'
out_dir = f'demo/sample_images/{task}/restored'
os.makedirs(out_dir, exist_ok=True)
extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']  # kept for parity; the glob below matches everything
files = natsorted(glob(os.path.join(input_dir, '*')))

img_multiple_of = 8  # encoder/decoder levels require H and W divisible by 8

print(f"\n ==> Running {task} with weights {weights}\n ")
with torch.no_grad():  # pure inference - skip autograd bookkeeping
  for filepath in tqdm(files):
      # release cached VRAM between images so large inputs still fit
      torch.cuda.ipc_collect()
      torch.cuda.empty_cache()

      img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)  # H,W,3 uint8
      # uint8 HWC -> float CHW in [0,1], add a batch dim, move to GPU
      input_ = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).cuda()

      # Reflect-pad height/width up to the next multiple of 8.
      h, w = input_.shape[2], input_.shape[3]
      padh = (img_multiple_of - h % img_multiple_of) % img_multiple_of
      padw = (img_multiple_of - w % img_multiple_of) % img_multiple_of
      input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

      restored_ = model(input_)                # same padded shape, values ~[0,1]
      restored = torch.clamp(restored_, 0, 1)  # clamp to the valid pixel range

      restored = restored[:, :, :h, :w]        # drop the padding again

      # BCHW -> HWC numpy, then scale to uint8 [0,255]
      restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
      restored = img_as_ubyte(restored[0])

      filename = os.path.split(filepath)[-1]
      cv2.imwrite(os.path.join(out_dir, filename), cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
In [ ]:
print(img.shape)
print(type(img))
print((torch.from_numpy(img).float()).shape)
print(torch.from_numpy(img)[0,0,0])
print(type(torch.from_numpy(img).float()[0,0,0]))
print(torch.from_numpy(img).float()[0,0,0].div(255.))
print(input_.shape)
print(input_.unsqueeze(0).shape)
print(restored_[0,0,0,0])
(400, 496, 3)
<class 'numpy.ndarray'>
torch.Size([400, 496, 3])
tensor(55, dtype=torch.uint8)
<class 'torch.Tensor'>
tensor(0.2157)
torch.Size([1, 3, 400, 496])
torch.Size([1, 1, 3, 400, 496])
tensor(0.1440, device='cuda:0')

6. Visualize Results - stock photos¶

In [ ]:
import matplotlib.pyplot as plt

# Show the first few degraded/restored pairs side by side.
inp_filenames = natsorted(glob(os.path.join(input_dir, '*')))
out_filenames = natsorted(glob(os.path.join(out_dir, '*')))

## Will display only first 5 images
num_display_images = 5
inp_filenames = inp_filenames[:num_display_images]
out_filenames = out_filenames[:num_display_images]

print(f"Results: {task}")
for inp_file, out_file in zip(inp_filenames, out_filenames):
  degraded = cv2.cvtColor(cv2.imread(inp_file), cv2.COLOR_BGR2RGB)
  restored = cv2.cvtColor(cv2.imread(out_file), cv2.COLOR_BGR2RGB)
  # Two borderless panels: input on the left, Restormer output on the right.
  fig, axes = plt.subplots(nrows=1, ncols=2)
  dpi = fig.get_dpi()
  fig.set_size_inches(900 / dpi, 448 / dpi)
  plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
  for ax, image in zip(axes, (degraded, restored)):
    ax.axis('off')
    ax.imshow(image)
  plt.show()
Results: Real_Denoising
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

7. Download Results - stock photos¶

In [ ]:
# Zip the demo inputs/outputs and push the archive to the browser (Colab-only;
# this import fails outside Colab).
from google.colab import files
zip_filename = f"Restormer_{task}.zip"
# NOTE(review): os.system with an interpolated name is fine while `task` is one
# of the fixed task labels - do not pass untrusted strings here.
os.system(f"zip -r {zip_filename} demo/sample_images/{task}")
files.download(zip_filename)

Apply Transformer model - Real noise Restormer on top of the FBP model¶

In [ ]:
# Guard: the following cells assume the Restormer repo root as cwd.
if os.getcwd() != '/home/hiran/Restormer':
  raise ImportError
# %cd /Restormer

"""
Implement 'Restormer' - model on Real img Noise on dataset 'lodopab'.
"""
# NOTE(review): hardcoded absolute/relative paths tie the notebook to one
# machine - consider a configurable base directory.
LOG_DIR2 = './logs/lodopab_restormer_real'
SAVE_BEST_LEARNED_PARAMS_PATH2 = './params/lodopab_restormer_real'
The LoDoPaB-CT dataset could not be found under the configured path '../../y'.
Do you want to download it now? (y: download, n: input other path)
n
Path to LoDoPaB dataset:
../y
updated configuration in '/home/hiran/.dival/config.json':
'lodopab_dataset/data_path' = ../y
In [ ]:
# ----------------------------
# load FBP model
# ----------------------------

# Filtered back-projection baseline: Hann filter band-limited to 80% of the
# Nyquist frequency (common smoothing settings for low-dose CT).
# `ray_trafo` is assumed to come from the CT-geometry setup cell earlier -
# TODO confirm it is in scope when this cell runs.
reconstructor = FBPReconstructor(ray_trafo, hyper_params={
    'filter_type': 'Hann',
    'frequency_scaling': 0.8})
In [ ]:
#%% evaluate
img_multiple_of = 8  # pad FBP reconstructions to a multiple of 8 for Restormer

# Reset the accumulators so the cell is idempotent across re-runs.
# Catch NameError specifically: the previous bare `except:` would also have
# hidden unrelated failures.
try:
  del recos2
  del psnrs2
except NameError:
  print('no recos2')
finally:
  recos2 = []
  psnrs2 = []

with torch.no_grad(): # inference only - no gradient bookkeeping
  for obs, gt in test_data:
      # FBP baseline reconstruction (odl element, H x W)
      reco = reconstructor.reconstruct(obs)
      reco = odl_to_single(reco) # 3d numpy array, H,W,1

      # The real-noise Restormer expects 3 channels: place the reconstruction
      # in channel 0 and zero-fill the other two.
      reco_ = reco[:, :, 0]
      reco1 = np.stack((reco_, np.zeros_like(reco_), np.zeros_like(reco_)), axis=-1) # H,W,3

      input_ = torch.from_numpy(reco1).permute(2, 0, 1).unsqueeze(0).cuda() # 1,3,H,W

      # Reflect-pad spatial dims up to the next multiple of 8.
      h, w = input_.shape[2], input_.shape[3]
      padh = (img_multiple_of - h % img_multiple_of) % img_multiple_of
      padw = (img_multiple_of - w % img_multiple_of) % img_multiple_of
      input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

      restored_ = model(input_)                # padded shape, values ~[0,1]
      restored = torch.clamp(restored_, 0, 1)  # clamp to valid pixel range

      # Unpad and keep only batch 0 / channel 0 (the CT image) for evaluation.
      restored2 = restored[0, 0, :h, :w]
      reco = restored2.cpu().detach().numpy()

      recos2.append(reco)
      psnrs2.append(PSNR(reco, gt))

# Report the mean PSNR and eyeball the first ten reconstructions.
print(f'mean psnr: {np.mean(psnrs2):f}')

for sample_idx in range(10):
    _, ax = plot_images([recos2[sample_idx], test_data.ground_truth[sample_idx]],
                        fig_size=(10, 4))
    ax[0].set_xlabel(f'PSNR: {psnrs2[sample_idx]:.2f}')
    ax[0].set_title('Restormer_Reconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle(f'test sample {sample_idx:d}')
mean psnr: 30.831097
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# Eyeball the first reconstruction against a ground-truth slice.
# NOTE(review): recos2 was computed from `test_data`, but the ground truth
# shown here comes from `test_data_2` - confirm both refer to the same split.
plt.imshow(recos2[0])
plt.show()
plt.imshow(test_data_2.ground_truth[0])
plt.show()

# for i in range(3):
#   plot_images(test_data_2.ground_truth[i])
No description has been provided for this image
No description has been provided for this image
In [ ]:
# Larger individual views of test sample 2: ground truth (a) vs. reconstruction (b).
a = plot_images([test_data_2.ground_truth[2]], fig_size=(10, 4))
b = plot_images([recos2[2]], fig_size=(10, 4))
No description has been provided for this image
No description has been provided for this image
In [ ]:
# Real world Denoising doesn't work well on CT reconstructions.
# Let's build up on the transformer restormer on gaussian noise.

Experimenting with a Hybrid NN -> Transformer/Gaussian Denoising NN on top of the FBP Model¶

Modifying the Gaussian Denoiser¶

  • Restormer: Efficient Transformer for High-Resolution Image Restoration
  • Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
  • https://arxiv.org/abs/2111.09881
In [ ]:
# if running in local setup
# Navigate into the repo and verify each hop; raising ImportError aborts early
# when the notebook was started from an unexpected directory.
%cd Restormer/
if os.getcwd() != '/home/hiran/Restormer':
  raise ImportError

%cd Denoising/
if os.getcwd() != '/home/hiran/Restormer/Denoising':
  raise ImportError
import utils
# Make sure the repo-local utils module won the import, not some other 'utils'.
if utils.__file__ != '/home/hiran/Restormer/Denoising/utils.py':
  raise ImportError

# Troubleshooting helpers kept for reference:
# !ls
# os.getcwd() # confirm cwd is Denoiser
# sys.path.insert(0,os.getcwd())
# utils.load_gray_img? # confirm functions are imported
# sys.modules.pop("utils",None) # if incorrect utils is loaded
# sys.path.append('basicsr')

# --------------------------------
# argparse from the upstream test script replaced with plain variables so the
# code runs inside a notebook cell.
# --------------------------------

# Define the necessary variables directly
input_dir = './Datasets/test/'  # directory of validation images
result_dir = './results/Gaussian_Gray_Denoising/' # directory for results
weights_path = './pretrained_models/gaussian_gray_denoising' # checkpoint path prefix
model_type = 'blind' # 'blind': one model for all noise levels; 'non_blind': one per sigma
sigmas_str = '15,25,50' # comma-separated test noise levels

####### Load yaml #######
if model_type == 'blind':
    yaml_file = 'Options/GaussianGrayDenoising_Restormer.yml'
else:
    # Assuming sigma for non_blind is the first one if multiple are given
    yaml_file = f'Options/GaussianGrayDenoising_RestormerSigma{sigmas_str.split(",")[0]}.yml'

try:
    from yaml import CLoader as Loader  # C loader is faster when available
except ImportError:
    from yaml import Loader

# Context manager closes the file handle; the previous `yaml.load(open(...))`
# leaked it.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

s = x['network_g'].pop('type') # remove the class name; the rest are ctor kwargs
##########################

sigmas = np.int_(sigmas_str.split(','))

factor = 8 # inputs are padded to a multiple of 8 before the network

datasets = ['Set12'] # ['Set12', 'BSD68', 'Urban100']

# For each test noise level: rebuild the model, load the matching checkpoint,
# then denoise every image of every configured dataset and save the results.
for sigma_test in sigmas:
    print("Compute results for noise level",sigma_test)
    model_restoration = Restormer(**x['network_g'])
    # blind: one checkpoint covers all sigmas; non_blind: one checkpoint per sigma
    if model_type == 'blind':
        weights = weights_path+'_blind.pth'
    else:
        weights = weights_path + '_sigma' + str(sigma_test) +'.pth'
    checkpoint = torch.load(weights)
    model_restoration.load_state_dict(checkpoint['params'])

    print("===>Testing using weights: ",weights)
    print("------------------------------------------------")
    model_restoration.cuda()
    # NOTE(review): DataParallel on a single GPU only adds overhead - kept to
    # mirror the upstream Restormer test script.
    model_restoration = nn.DataParallel(model_restoration)
    model_restoration.eval()

    for dataset in datasets:
        inp_dir = os.path.join(input_dir, dataset)
        files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif')))
        result_dir_tmp = os.path.join(result_dir, model_type, dataset, str(sigma_test))
        os.makedirs(result_dir_tmp, exist_ok=True)

        with torch.no_grad():
            for file_ in tqdm(files):
                # free cached VRAM between images
                torch.cuda.ipc_collect()
                torch.cuda.empty_cache()
                img = np.float32(utils.load_gray_img(file_))/255. # normalised pixel vals in single precision # shape presumably H,W,1 (inferred from the permute below) - TODO confirm

                # Synthesize the noisy input at the current sigma.
                np.random.seed(seed=0)  # for reproducibility
                img += np.random.normal(0, sigma_test/255., img.shape)

                img = torch.from_numpy(img).permute(2,0,1)
                input_ = img.unsqueeze(0).cuda() # shape B:1?, C:1?,H,W

                # Padding in case images are not multiples of 8
                h,w = input_.shape[2], input_.shape[3]
                H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
                padh = H-h if h%factor!=0 else 0
                padw = W-w if w%factor!=0 else 0
                input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

                restored = model_restoration(input_) # output normalised pixel vals in shape BCHW

                # Unpad images to original dimensions
                restored = restored[:,:,:h,:w]

                restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() # clamped normalised pixel vals in 3d array with shape H,W,C:1

                save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1])
                utils.save_gray_img(save_file, img_as_ubyte(restored))
Out[ ]:
'/content'

Apply Transformer NN - Blind Gaussian Restormer on top of the FBP model¶

In [ ]:
# Guard: subsequent cells expect the Denoising subfolder as cwd.
if os.getcwd() != '/home/hiran/Restormer/Denoising':
  raise ImportError
# %cd Denoising
"""
Implement  Gaussian Restormer on 'lodopab'.
"""
IMPL = 'astra_cuda'  # odl/astra backend selector for the ray transform
# NOTE(review): relative paths below resolve against the Denoising subfolder.
LOG_DIR2 = '../../logs/lodopab_restormer_Gaussian'
SAVE_BEST_LEARNED_PARAMS_PATH2 = '../../params/lodopab_restormer_Gaussian'
In [ ]:
# ----------------------------
# load FBP model
# ----------------------------

# Same FBP baseline as in the real-noise experiment: Hann filter band-limited
# to 80% of the Nyquist frequency. `ray_trafo` is assumed to come from the
# earlier CT-geometry setup cell - TODO confirm it is in scope here.
reconstructor = FBPReconstructor(ray_trafo, hyper_params={
    'filter_type': 'Hann',
    'frequency_scaling': 0.8})
In [ ]:
#%% evaluate
img_multiple_of = 8  # pad FBP reconstructions to a multiple of 8 for Restormer

# Reset the accumulators so the cell is idempotent across re-runs.
# Catch NameError specifically: the previous bare `except:` would also have
# hidden unrelated failures.
try:
  del recos2
  del psnrs2
except NameError:
  print('no recos2')
finally:
  recos2 = []
  psnrs2 = []

with torch.no_grad(): # inference only - no gradient bookkeeping
  for obs, gt in test_data:
      # free cached VRAM between samples
      torch.cuda.ipc_collect()
      torch.cuda.empty_cache()

      # FBP baseline reconstruction (odl element, H x W)
      reco = reconstructor.reconstruct(obs)
      reco = odl_to_single(reco) # 3d numpy array, H,W,1

      # The grayscale Restormer takes a single channel directly, so no RGB
      # stacking is needed here (unlike the real-noise model above).
      input_ = torch.from_numpy(reco).permute(2, 0, 1).unsqueeze(0).cuda() # 1,C,H,W

      # Reflect-pad spatial dims up to the next multiple of 8.
      h, w = input_.shape[2], input_.shape[3]
      padh = (img_multiple_of - h % img_multiple_of) % img_multiple_of
      padw = (img_multiple_of - w % img_multiple_of) % img_multiple_of
      input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

      restored_ = model_restoration(input_)       # padded shape, values ~[0,1]
      restored_ = torch.clamp(restored_, 0, 1)    # clamp to valid pixel range

      # Unpad; keep only batch 0 / channel 0 for evaluation.
      restored_ = restored_[0, 0, :h, :w]
      reco2 = restored_.cpu().detach().numpy()

      recos2.append(reco2)
      psnrs2.append(PSNR(reco2, gt))

# Report the mean PSNR and eyeball the first ten reconstructions.
print(f'mean psnr: {np.mean(psnrs2):f}')

for sample_idx in range(10):
    _, ax = plot_images([recos2[sample_idx], test_data.ground_truth[sample_idx]],
                        fig_size=(10, 4))
    ax[0].set_xlabel(f'PSNR: {psnrs2[sample_idx]:.2f}')
    ax[0].set_title('Restormer_Reconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle(f'test sample {sample_idx:d}')
mean psnr: 31.231878
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# Eyeball the first reconstruction against a ground-truth slice.
# NOTE(review): recos2 was computed from `test_data`, but the ground truth
# shown here comes from `test_data_2` - confirm both refer to the same split.
plt.imshow(recos2[0])
plt.show()
plt.imshow(test_data_2.ground_truth[0])
plt.show()

# for i in range(3):
#   plot_images(test_data_2.ground_truth[i])
No description has been provided for this image
No description has been provided for this image
In [ ]:
# NOTE(review): `_` holds whatever the last expression/assignment left behind
# (the axes unpacked in the plotting loop above) - relying on leftover
# underscore state breaks under Restart & Run All.
print(_)
[<matplotlib.image.AxesImage object at 0x75c32a0321a0>
 <matplotlib.image.AxesImage object at 0x75c32a032200>]
In [ ]:
# Larger individual views of test sample 2: ground truth (a) vs. reconstruction (b).
a = plot_images([test_data_2.ground_truth[2]], fig_size=(10, 4))
b = plot_images([recos2[2]], fig_size=(10, 4))
No description has been provided for this image
No description has been provided for this image
In [ ]:
# Display the input to the model
# print("Input to model:")
# print(input_)
# print(input_.shape)
# print(input_.min(), input_.max())

# Display the output of the model before clamping
# print("\nOutput of model before clamping:")
# print(restored_)
# print(restored_.shape)
# print(restored_.min(), restored_.max())

# Display the model parameters
# print("\nModel parameters:")
for name, param in model_restoration.named_parameters():
    if param.requires_grad:
        # print(name, param.data)

Building a Novel Hybrid -> Training the Transformer NN on 70K Images dataset¶

Images from

  • Div2K
  • WaterlooED
  • Flickr2K
  • BSD400
In [ ]:
# os.getcwd()
# %cd Restormer/Denoising

import datetime
import logging
import math
import time
# import os # Import the os module
from os import path as osp

from torch.utils import data as data
from torchvision.transforms.functional import normalize
# %cd Denoising/
# sys.path.append('/home/hiran/Restormer')

# direct basicsr if in ./Denoising
from basicsr.data import create_dataloader, create_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import create_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info,
                           get_root_logger, get_time_str, init_tb_logger,
                           init_wandb_logger, make_exp_dirs, mkdir_and_rename,
                           set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse

from basicsr.data.data_util import (paired_paths_from_folder,paired_DP_paths_from_folder,paired_paths_from_lmdb,paired_paths_from_meta_info_file,paths_from_lmdb)
from basicsr.data.transforms import augment, paired_random_crop, paired_random_crop_DP,random_augmentation
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, padding_DP,imfrombytesDP,scandir
import importlib
# importlib.reload(basicr.utils.scandir) # not a module, error
In [ ]:
# how to run train.sh for single gpu in bash
# chmod +x train_single_gpu.sh
# ./train_single_gpu.sh path/to/your_config.yml
In [ ]:
# !pip install gdown
# import gdown

# !python download_data.py --data train-test --noise gaussian

# shutil.unpack_archive('Datasets/Downloads/Flickr2K.zip', 'Datasets/Downloads')

# shutil.unpack_archive('Datasets/Downloads/DIV2K.zip', 'Datasets/Downloads')
# os.remove('Datasets/Downloads/DIV2K.zip')
# os.remove('Datasets/Downloads/Flickr2K.zip')

# !python generate_patches_dfwb.py
# shutil.rmtree('Datasets/Downloads') # REMEMBER to del and save 30GB+ # deleted

# gaussian_test = '1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0'   ## https://drive.google.com/file/d/1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0/view?usp=sharing

# print('Gaussian Denoising Testing Data!')
# gdown.download(id=gaussian_test, output='Datasets/test.zip', quiet=False)
# os.system(f'gdrive download {gaussian_test} --path Datasets/')
# print('Extracting Data...')
# shutil.unpack_archive('Datasets/test.zip', 'Datasets')
# os.remove('Datasets/test.zip')

Model Parameters¶

In [ ]:
# this is a settings file, no use in running in plain colab. for hyper parameter tuning in YML
# sync changes from colab to YML periodically at ./Options/GaussianGrayDenoising_Restormer.YML

# general settings
name: GaussianGrayDenoising_Restormer
model_type: ImageCleanModel
scale: 1
num_gpu: 1 # 8  # set num_gpu: 0 for cpu mode
manual_seed: 100

# dataset and data loader settings
datasets:
  train:
    phase: train # added since, missing param for def train dataloader
    name: TrainSet
    type: Dataset_GaussianDenoising
    sigma_type: random
    sigma_range: [0,50]
    in_ch: 1    ## Grayscale image
    dataroot_gt: ./Datasets/train/DFWB
    dataroot_lq: none
    geometric_augs: true

    filename_tmpl: '{}'
    io_backend:
      type: disk

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 4 # 8
    batch_size_per_gpu: 2 # 8

    ### -------------Progressive training--------------------------
    mini_batch_sizes: [2,1,1,1,1,1] # [8,5,4,2,1,1]             # Batch size per gpu
    iters: [92000,64000,48000,36000,36000,24000]
    gt_size: 320 # 384   # Max patch size for progressive training
    gt_sizes: [128,160,192,256,320] # [128,160,192,256,320,384]  # Patch sizes for progressive training.
    # NOTE(review): gt_sizes lists 5 patch sizes while iters/mini_batch_sizes
    # list 6 stages - confirm the trainer tolerates the mismatch or drop one stage.
    ### ------------------------------------------------------------

    ### ------- Training on single fixed-patch size 128x128---------
    # mini_batch_sizes: [8]
    # iters: [300000]
    # gt_size: 128
    # gt_sizes: [128]
    ### ------------------------------------------------------------

    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  val:
    phase: val # added since, missing param for def train dataloader
    name: ValSet
    type: Dataset_GaussianDenoising
    sigma_test: 25
    in_ch: 1  ## Grayscale image
    dataroot_gt: ./Datasets/test/BSD68
    dataroot_lq: none
    io_backend:
      type: disk

# network structures
network_g:
  type: Restormer
  inp_channels: 1
  out_channels: 1
  dim: 48
  num_blocks: [4,6,6,8]
  num_refinement_blocks: 4
  heads: [1,2,4,8]
  ffn_expansion_factor: 2.66
  bias: False
  LayerNorm_type: BiasFree
  dual_pixel_task: False


# path
path:
  pretrain_network_g: ~
  strict_load_g: true
  resume_state: ~

# training settings
train:
  total_iter: 300000
  warmup_iter: -1 # no warm up
  use_grad_clip: true

  # Split 300k iterations into two cycles.
  # 1st cycle: fixed 3e-4 LR for 92k iters.
  # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
  scheduler:
    type: CosineAnnealingRestartCyclicLR
    periods: [92000, 208000]
    restart_weights: [1,1]
    eta_mins: [0.0003,0.000001]

  mixing_augs:
    mixup: true
    mixup_beta: 1.2
    use_identity: true

  optim_g:
    type: AdamW
    lr: !!float 3e-4
    weight_decay: !!float 1e-4
    betas: [0.9, 0.999]

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1
    reduction: mean

# validation settings
val:
  window_size: 8
  val_freq: !!float 4e3
  save_img: false
  rgb2bgr: true
  use_image: false
  max_minibatch: 8

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false

# logging settings
logger:
  print_freq: 1000
  save_checkpoint_freq: !!float 4e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
In [ ]:
def parse_(opt_path: str, is_train: bool = True, launcher='none', local_rank=0) -> dict:
    """Load the training options YAML and attach runtime/distributed settings.

    Args:
        opt_path: path to the YAML options file.
        is_train: stored on the result as ``opt['is_train']``.
        launcher: 'none' disables distributed training; other values trigger
            ``init_dist`` ('slurm' forwards ``dist_params`` from the YAML).
        local_rank: stored verbatim as ``opt['local_rank']``.

    Returns:
        The parsed options dict augmented with ``is_train``, ``launcher``,
        ``local_rank``, ``dist``, ``rank``, ``world_size`` and ``manual_seed``.

    Raises:
        FileNotFoundError: if ``opt_path`` does not exist.
        ValueError: if the YAML root is not a mapping.

    Side effects: may initialize distributed training and always seeds the
    RNGs via ``set_random_seed(seed + rank)``.
    """
    options_file = Path(opt_path)
    if not options_file.exists():
        raise FileNotFoundError(f"YAML options file not found: {opt_path}")
    with options_file.open('r') as handle:
        opt = yaml.safe_load(handle)
    if not isinstance(opt, dict):
        raise ValueError("Parsed YAML is not a mapping (dict).")

    # Record the call-time flags on the options dict for downstream code.
    opt['is_train'] = bool(is_train)
    opt['launcher'] = launcher
    opt['local_rank'] = local_rank

    # Distributed setup: anything other than 'none' initializes a process group.
    if opt['launcher'] == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if opt['launcher'] == 'slurm' and 'dist_params' in opt:
            init_dist(opt['launcher'], **opt['dist_params'])
        else:
            init_dist(opt['launcher'])
            print('init dist .. ', opt['launcher'])

    opt['rank'], opt['world_size'] = get_dist_info()

    # Seed: honour manual_seed from the YAML, otherwise draw one and record it.
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    return opt

# Parse the grayscale Gaussian-denoising options; the trailing bare expression
# displays the resulting dict for inspection.
opt_path = './Options/GaussianGrayDenoising_Restormer.yml'
opt = parse_(opt_path, is_train=True, launcher='none', local_rank=0 )
opt
Disable distributed.
Out[ ]:
{'name': 'GaussianGrayDenoising_Restormer',
 'model_type': 'ImageCleanModel',
 'scale': 1,
 'num_gpu': 1,
 'manual_seed': 100,
 'datasets': {'train': {'phase': 'train',
   'name': 'TrainSet',
   'type': 'Dataset_GaussianDenoising',
   'sigma_type': 'random',
   'sigma_range': [0, 50],
   'in_ch': 1,
   'dataroot_gt': './Datasets/train/DFWB',
   'dataroot_lq': 'none',
   'geometric_augs': True,
   'filename_tmpl': '{}',
   'io_backend': {'type': 'disk'},
   'use_shuffle': True,
   'num_worker_per_gpu': 4,
   'batch_size_per_gpu': 2,
   'mini_batch_sizes': [2, 1, 1, 1, 1, 1],
   'iters': [92000, 64000, 48000, 36000, 36000, 24000],
   'gt_size': 320,
   'gt_sizes': [128, 160, 192, 256, 320],
   'dataset_enlarge_ratio': 1,
   'prefetch_mode': None},
  'val': {'phase': 'val',
   'name': 'ValSet',
   'type': 'Dataset_GaussianDenoising',
   'sigma_test': 25,
   'in_ch': 1,
   'dataroot_gt': './Datasets/test/BSD68',
   'dataroot_lq': 'none',
   'io_backend': {'type': 'disk'}}},
 'network_g': {'type': 'Restormer',
  'inp_channels': 1,
  'out_channels': 1,
  'dim': 48,
  'num_blocks': [4, 6, 6, 8],
  'num_refinement_blocks': 4,
  'heads': [1, 2, 4, 8],
  'ffn_expansion_factor': 2.66,
  'bias': False,
  'LayerNorm_type': 'BiasFree',
  'dual_pixel_task': False},
 'path': {'pretrain_network_g': None,
  'strict_load_g': True,
  'resume_state': None},
 'train': {'total_iter': 300000,
  'warmup_iter': -1,
  'use_grad_clip': True,
  'scheduler': {'type': 'CosineAnnealingRestartCyclicLR',
   'periods': [92000, 208000],
   'restart_weights': [1, 1],
   'eta_mins': [0.0003, 1e-06]},
  'mixing_augs': {'mixup': True, 'mixup_beta': 1.2, 'use_identity': True},
  'optim_g': {'type': 'AdamW',
   'lr': 0.0003,
   'weight_decay': 0.0001,
   'betas': [0.9, 0.999]},
  'pixel_opt': {'type': 'L1Loss', 'loss_weight': 1, 'reduction': 'mean'}},
 'val': {'window_size': 8,
  'val_freq': 4000.0,
  'save_img': False,
  'rgb2bgr': True,
  'use_image': False,
  'max_minibatch': 8,
  'metrics': {'psnr': {'type': 'calculate_psnr',
    'crop_border': 0,
    'test_y_channel': False}}},
 'logger': {'print_freq': 1000,
  'save_checkpoint_freq': 4000.0,
  'use_tb_logger': True,
  'wandb': {'project': None, 'resume_id': None}},
 'dist_params': {'backend': 'nccl', 'port': 29500},
 'is_train': True,
 'launcher': 'none',
 'local_rank': 0,
 'dist': False,
 'rank': 0,
 'world_size': 1}

Training the Novel Hybrid Network¶

In [ ]:
# def parse_options(is_train=True): # parsed config from section 'Model Param'
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         '-opt', type=str, required=True, help='Path to option YAML file.')
#     parser.add_argument(
#         '--launcher',
#         choices=['none', 'pytorch', 'slurm'],
#         default='none',
#         help='job launcher')
#     parser.add_argument('--local_rank', type=int, default=0)
#     args = parser.parse_args()
#     opt = parse(args.opt, is_train=is_train)

#     # distributed settings
#     if args.launcher == 'none':
#         opt['dist'] = False
#         print('Disable distributed.', flush=True)
#     else:
#         opt['dist'] = True
#         if args.launcher == 'slurm' and 'dist_params' in opt:
#             init_dist(args.launcher, **opt['dist_params'])
#         else:
#             init_dist(args.launcher)
#             print('init dist .. ', args.launcher)

#     opt['rank'], opt['world_size'] = get_dist_info()

#     # random seed
#     seed = opt.get('manual_seed')
#     if seed is None:
#         seed = random.randint(1, 10000)
#         opt['manual_seed'] = seed
#     set_random_seed(seed + opt['rank'])

#     return opt


def init_loggers(opt):
    """Set up the loggers for a training run.

    Creates the root text logger (file under ./experiments/logs/), optionally
    initializes a Weights & Biases logger, and optionally a TensorBoard
    writer under ./tb_logger/<name>.

    Args:
        opt (dict): options dict; reads opt['name'] and opt['logger']
            (keys 'wandb', 'use_tb_logger').

    Returns:
        tuple: (logger, tb_logger) where tb_logger is None when TensorBoard
        is disabled or the run name contains 'debug'.
    """
    log_file = osp.join('./experiments/logs/',
                        f"train_{opt['name']}_{get_time_str()}.log") # opt['path']['log'] was first argument
    logger = get_root_logger(
        logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    # record environment + full option dump at the top of the log file
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # initialize wandb logger before tensorboard logger to allow proper sync:
    if (opt['logger'].get('wandb')
            is not None) and (opt['logger']['wandb'].get('project')
                              is not None) and ('debug' not in opt['name']):
        assert opt['logger'].get('use_tb_logger') is True, (
            'should turn on tensorboard when using wandb')
        init_wandb_logger(opt)
    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
    return logger, tb_logger

# Dataset_GaussianDenoising
def create_train_val_dataloader(opt, logger):
    # create train and val dataloaders
    train_loader, val_loader = None, None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)

            # generalisation tool - scale automatically alongside flipping, padding (bool)
            dataset_opt['scale'] = opt['scale']

            train_set = create_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'],
                                            opt['rank'], dataset_enlarge_ratio)
            train_loader = create_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'])

            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio /
                (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info(
                'Training statistics:'
                f'\n\tNumber of train images: {len(train_set)}'
                f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                f'\n\tWorld size (gpu number): {opt["world_size"]}'
                f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')

        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(
                val_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=None,
                seed=opt['manual_seed'])
            logger.info(
                f'Number of val images/folders in {dataset_opt["name"]}: '
                f'{len(val_set)}')
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loader, total_epochs, total_iters


# def main(): # indent again when complete

# parse options, set distributed setting, set random seed
# opt = parse_options(is_train=True) # opt dict created earlier in this section

# Fill in any 'path' keys the YAML did not define. `dict.setdefault` writes
# only when the key is absent, matching the original per-key `if ... not in`
# guards while reading as one table of defaults.
opt['path'].setdefault('experiments_root', './experiments/experiments')
opt['path'].setdefault('log', './experiments/logs')
opt['path'].setdefault('models', './model_zoo')
opt['path'].setdefault('training_states', './experiments/training_states')

# Propagate the augmentation scaling factor to the train & val dataset configs.
opt['datasets']['train']['scale'] = opt['scale']
opt['datasets']['val']['scale'] = opt['scale']

# cuDNN autotuner: benchmarks conv algorithms once per input shape, then
# reuses the fastest (non-deterministic across runs).
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# automatic resume: pick the newest '<iter>.state' checkpoint if one exists.
state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
import os
try:
    states = os.listdir(state_folder_path)
except OSError:  # folder absent on a fresh run (was a bare `except:`)
    states = []

resume_state = None
# Only consider real checkpoint files; a stray file in the folder would have
# crashed the original `int(x[0:-6])` parse.
state_iters = [int(os.path.splitext(x)[0]) for x in states
               if x.endswith('.state')]
if state_iters:
    max_state_file = '{}.state'.format(max(state_iters))
    resume_state = os.path.join(state_folder_path, max_state_file)
    opt['path']['resume_state'] = resume_state

# load resume states if necessary (mapped onto the current CUDA device)
if opt['path'].get('resume_state'):
    device_id = torch.cuda.current_device()
    resume_state = torch.load(
        opt['path']['resume_state'],
        map_location=lambda storage, loc: storage.cuda(device_id))
else:
    resume_state = None

# mkdir for experiments and logger # if error confirm that experiment path doesn't end with /
# sys.path.append('/home/hiran/Restormer')
# sys.path.remove('/home/hiran/Restormer')
# %cd Denoising/
# os.getcwd()

# Only create (and archive) experiment directories on a fresh run; resuming
# reuses the existing folders.
if resume_state is None:
    # # Modified mkdir_and_rename to handle non-existent directories
    # def mkdir_and_rename_modified(path):
    #     if os.path.exists(path):
    #         new_name = path + '_archived_' + get_time_str()
    #         print(f'Path already exists. Rename it to {new_name}', flush=True)
    #         os.rename(path, new_name)
    #     os.makedirs(path, exist_ok=True)

    # # Replace the original mkdir_and_rename with the modified one
    # original_mkdir_and_rename = mkdir_and_rename
    # mkdir_and_rename = mkdir_and_rename_modified

    make_exp_dirs(opt)

    # # Restore the original mkdir_and_rename
    # mkdir_and_rename = original_mkdir_and_rename

    # rank-0 only: archive any stale tb_logger dir for this run name
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
            'name'] and opt['rank'] == 0:
        mkdir_and_rename(osp.join('tb_logger', opt['name']))

# initialize loggers
logger, tb_logger = init_loggers(opt)

# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger) # if scandir error pls. import from above
train_loader, train_sampler, val_loader, total_epochs, total_iters = result

# create model (and restore optimizer/scheduler state when resuming)
if resume_state:  # resume training
    check_resume(opt, resume_state['iter'])
    model = create_model(opt)
    model.resume_training(resume_state)  # handle optimizers and schedulers
    logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
                f"iter: {resume_state['iter']}.")
    start_epoch = resume_state['epoch']
    current_iter = resume_state['iter']
else:
    model = create_model(opt)
    start_epoch = 0
    current_iter = 0

# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger)

# dataloader prefetcher: CPU prefetch by default; CUDA prefetch requires
# pin_memory=True in the train dataset options.
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu':
    prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda':
    prefetcher = CUDAPrefetcher(train_loader, opt)
    logger.info(f'Use {prefetch_mode} prefetch dataloader')
    if opt['datasets']['train'].get('pin_memory') is not True:
        raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
else:
    raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.'
                      "Supported ones are: None, 'cuda', 'cpu'.")

# training
logger.info(
    f'Start training from epoch: {start_epoch}, iter: {current_iter}')
data_time, iter_time = time.time(), time.time()
start_time = time.time()

# for epoch in range(start_epoch, total_epochs + 1):

# progressive-training schedule from the YAML train-dataset options
iters = opt['datasets']['train'].get('iters')
batch_size = opt['datasets']['train'].get('batch_size_per_gpu')
mini_batch_sizes = opt['datasets']['train'].get('mini_batch_sizes')
gt_size = opt['datasets']['train'].get('gt_size')
mini_gt_sizes = opt['datasets']['train'].get('gt_sizes')

# cumulative iteration boundaries of each progressive stage; np.cumsum
# replaces the original O(n^2) `sum(iters[0:i+1])` comprehension with the
# same values.
groups = np.cumsum(iters)

# one-shot flags so each stage's patch/batch change is logged exactly once
logger_j = [True] * len(groups)

scale = opt['scale']

# Main training loop: epochs are driven by the iteration budget, not by
# dataset passes. Timer pattern: data_time/iter_time hold timestamps at the
# bottom of the loop and are converted to durations at the top.
epoch = start_epoch
while current_iter <= total_iters:
    train_sampler.set_epoch(epoch)  # reshuffle per epoch
    prefetcher.reset()
    train_data = prefetcher.next()

    while train_data is not None:
        data_time = time.time() - data_time  # time spent fetching the batch

        current_iter += 1
        if current_iter > total_iters:
            break
        # update learning rate
        model.update_learning_rate(
            current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))


        ### ------Progressive learning ---------------------
        # bs_j = index of the current stage: first group boundary that
        # current_iter has not yet passed (falls back to the last stage).
        j = ((current_iter>groups) !=True).nonzero()[0]
        if len(j) == 0:
            bs_j = len(groups) - 1
        else:
            bs_j = j[0]

        mini_gt_size = mini_gt_sizes[bs_j]
        mini_batch_size = mini_batch_sizes[bs_j]

        # announce each stage's patch/batch size exactly once
        if logger_j[bs_j]:
            logger.info('\n Updating Patch_Size to {} and Batch_Size to {} \n'.format(mini_gt_size, mini_batch_size*torch.cuda.device_count()))
            logger_j[bs_j] = False

        lq = train_data['lq']
        gt = train_data['gt']

        # sub-sample the loaded batch down to the stage's mini batch size
        if mini_batch_size < batch_size:
            indices = random.sample(range(0, batch_size), k=mini_batch_size)
            lq = lq[indices]
            gt = gt[indices]

        # random-crop both tensors to the stage's patch size
        # (assumes NCHW layout -- lq/gt are indexed [:,:,x,y])
        if mini_gt_size < gt_size:
            x0 = int((gt_size - mini_gt_size) * random.random())
            y0 = int((gt_size - mini_gt_size) * random.random())
            x1 = x0 + mini_gt_size
            y1 = y0 + mini_gt_size
            lq = lq[:,:,x0:x1,y0:y1]
            gt = gt[:,:,x0*scale:x1*scale,y0*scale:y1*scale]
        ###-------------------------------------------


        model.feed_train_data({'lq': lq, 'gt':gt})
        model.optimize_parameters(current_iter)

        iter_time = time.time() - iter_time  # wall time of this iteration
        # log
        if current_iter % opt['logger']['print_freq'] == 0:
            log_vars = {'epoch': epoch, 'iter': current_iter}
            log_vars.update({'lrs': model.get_current_learning_rate()})
            log_vars.update({'time': iter_time, 'data_time': data_time})
            log_vars.update(model.get_current_log())
            msg_logger(log_vars)

        # save models and training states
        if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
            logger.info('Saving models and training states.')
            model.save(epoch, current_iter)

        # validation
        if opt.get('val') is not None and (current_iter %
                                            opt['val']['val_freq'] == 0):
            rgb2bgr = opt['val'].get('rgb2bgr', True)
            # whether use uint8 image to compute metrics
            use_image = opt['val'].get('use_image', True)
            model.validation(val_loader, current_iter, tb_logger,
                              opt['val']['save_img'], rgb2bgr, use_image )

        # reset timers for the next iteration
        data_time = time.time()
        iter_time = time.time()
        train_data = prefetcher.next()
    # end of iter
    epoch += 1

# end of epoch

consumed_time = str(
    datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}')
logger.info('Save the latest model.')
model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
if opt.get('val') is not None:
    # NOTE(review): final validation omits rgb2bgr/use_image, relying on the
    # method's defaults -- confirm this matches the in-loop call above.
    model.validation(val_loader, current_iter, tb_logger,
                      opt['val']['save_img'])
if tb_logger:
    tb_logger.close()


# if __name__ == '__main__':
#     main()
In [ ]:
 

Training Transformer NN on real CT scans - Lodopab dataset (80K 362x362 CT scans)¶

Creating the low density CT dataset¶

In [ ]:
# Output roots for the LoDoPaB-Restormer run.
# NOTE(review): LOG_DIR2 / SAVE_BEST_LEARNED_PARAMS_PATH2 are not referenced
# elsewhere in this section -- confirm later cells consume them.
LOG_DIR2 = '../../logs/lodopab_restormer_Gaussian'
SAVE_BEST_LEARNED_PARAMS_PATH2 = '../../params/lodopab_restormer_Gaussian'
In [ ]:
# ----------------------------
# load FBP model
# ----------------------------

# from ctexample section
# Filtered back-projection baseline (Hann filter, 0.8 frequency scaling) used
# to synthesise the low-quality inputs; `ray_trafo` is the forward operator
# defined in an earlier cell.
reconstructor = FBPReconstructor(ray_trafo, hyper_params={
    'filter_type': 'Hann',
    'frequency_scaling': 0.8})
In [ ]:
lq_path = '/home/hiran/Restormer/Denoising/Datasets/train_RetrainCT/lq/'
gt_path = '/home/hiran/Restormer/Denoising/Datasets/train_RetrainCT/gt/'

def create_LD_ds(train_ds, lq_path, gt_path):
    """Write a paired low-dose-CT dataset to disk as PNGs.

    For each (observation, ground-truth) pair in `train_ds`, the global FBP
    `reconstructor` produces the low-quality input, saved under `lq_path`;
    the ground truth is saved under `gt_path`. Files are named <index>.png
    in iteration order.
    NOTE(review): paths are absolute to this machine -- consider a DATA_DIR.
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for idx, (obs, gt) in enumerate(train_ds):
            # free inter-process handles and cached blocks to keep VRAM low
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

            # FBP reconstruction: odl element H,W 362,362, normalized pixels
            fbp_reco = reconstructor.reconstruct(obs)
            lq_img = utilsImg.single2uint(odl_to_single(fbp_reco))  # H,W,1 uint8
            utilsImg.imsave(lq_img, lq_path + str(idx) + '.png')  # expects 3-d array

            gt_img = utilsImg.single2uint(odl_to_single(gt))  # H,W,1 uint8
            cv2.imwrite(gt_path + str(idx) + '.png', gt_img.squeeze())  # expects 2-d array

# create_LD_ds(train_ds,lq_path,gt_path)

val_lq_path = '/home/hiran/Restormer/Denoising/Datasets/validation_RetrainCT/lq/'
val_gt_path = '/home/hiran/Restormer/Denoising/Datasets/validation_RetrainCT/gt/'

create_LD_ds(validation_data, val_lq_path, val_gt_path)

Some Preprocessing steps¶

In [ ]:
# Many incompatibilities in git, better sync with my gdrive
# Clone Restormer
# !git clone https://github.com/swz30/Restormer.git
In [ ]:
# if running in gdrive

# %cd # for cd root
# from google.colab import drive # or use left panel UI
# drive.mount('/content/gdrive')
# os.getcwd()
# %cd /content/drive/MyDrive/Colab Notebooks/Restormer_1209/Denoising

# if in a local setup

# Guard: the relative paths used below assume we run from Restormer/Denoising.
# The original raised a bare ImportError with no message; keep the type (so
# any existing handling still works) but explain what went wrong.
if os.getcwd() != '/home/hiran/Restormer/Denoising':
  raise ImportError(
      f"Expected cwd /home/hiran/Restormer/Denoising, got {os.getcwd()}")
# %cd Restormer/Denoising
In [ ]:
# !pip install gdown
# import gdown
# !pip install gdrive
# import gdrive

# !python download_data.py --data train-test --noise gaussian

# shutil.unpack_archive('Datasets/Downloads/Flickr2K.zip', 'Datasets/Downloads')

# shutil.unpack_archive('Datasets/Downloads/DIV2K.zip', 'Datasets/Downloads')
# os.remove('Datasets/Downloads/DIV2K.zip')
# os.remove('Datasets/Downloads/Flickr2K.zip')

# !python generate_patches_dfwb.py
# os.remove('Datasets/Downloads') # REMEMBER to del and save 30GB+

# gaussian_test = '1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0'   ## https://drive.google.com/file/d/1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0/view?usp=sharing

# print('Gaussian Denoising Testing Data!')
# gdown.download(id=gaussian_test, output='Datasets/test.zip', quiet=False)
# os.system(f'gdrive download {gaussian_test} --path Datasets/')
# print('Extracting Data...')
# shutil.unpack_archive('Datasets/test.zip', 'Datasets')
# os.remove('Datasets/test.zip')
Collecting gdrive
  Downloading gdrive-0.1.5-py3-none-any.whl.metadata (814 bytes)
Collecting setuptools~=59.6.0 (from gdrive)
  Downloading setuptools-59.6.0-py3-none-any.whl.metadata (5.0 kB)
Collecting wheel~=0.37.1 (from gdrive)
  Downloading wheel-0.37.1-py2.py3-none-any.whl.metadata (2.3 kB)
Collecting versioneer~=0.22 (from gdrive)
  Downloading versioneer-0.29-py3-none-any.whl.metadata (16 kB)
Collecting argparse~=1.4.0 (from gdrive)
  Downloading argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting google-api-python-client~=2.43.0 (from gdrive)
  Downloading google_api_python_client-2.43.0-py2.py3-none-any.whl.metadata (6.6 kB)
Collecting google-auth-oauthlib~=0.5.1 (from gdrive)
  Downloading google_auth_oauthlib-0.5.3-py2.py3-none-any.whl.metadata (2.7 kB)
Requirement already satisfied: SecretStorage~=3.3.1 in /usr/local/lib/python3.12/dist-packages (from gdrive) (3.3.3)
Requirement already satisfied: httplib2<1dev,>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client~=2.43.0->gdrive) (0.30.0)
Requirement already satisfied: google-auth<3.0.0dev,>=1.16.0 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client~=2.43.0->gdrive) (2.38.0)
Requirement already satisfied: google-auth-httplib2>=0.1.0 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client~=2.43.0->gdrive) (0.2.0)
Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client~=2.43.0->gdrive) (2.25.1)
Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client~=2.43.0->gdrive) (4.2.0)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from google-auth-oauthlib~=0.5.1->gdrive) (2.0.0)
Requirement already satisfied: cryptography>=2.0 in /usr/local/lib/python3.12/dist-packages (from SecretStorage~=3.3.1->gdrive) (43.0.3)
Requirement already satisfied: jeepney>=0.6 in /usr/local/lib/python3.12/dist-packages (from SecretStorage~=3.3.1->gdrive) (0.9.0)
Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.12/dist-packages (from cryptography>=2.0->SecretStorage~=3.3.1->gdrive) (1.17.1)
Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (1.70.0)
Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (5.29.5)
Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (1.26.1)
Requirement already satisfied: requests<3.0.0,>=2.18.0 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (2.32.4)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth<3.0.0dev,>=1.16.0->google-api-python-client~=2.43.0->gdrive) (5.5.2)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth<3.0.0dev,>=1.16.0->google-api-python-client~=2.43.0->gdrive) (0.4.2)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth<3.0.0dev,>=1.16.0->google-api-python-client~=2.43.0->gdrive) (4.9.1)
Requirement already satisfied: pyparsing<4,>=3.0.4 in /usr/local/lib/python3.12/dist-packages (from httplib2<1dev,>=0.15.0->google-api-python-client~=2.43.0->gdrive) (3.2.3)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.12/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib~=0.5.1->gdrive) (3.3.1)
Requirement already satisfied: pycparser in /usr/local/lib/python3.12/dist-packages (from cffi>=1.12->cryptography>=2.0->SecretStorage~=3.3.1->gdrive) (2.22)
Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3.0.0dev,>=1.16.0->google-api-python-client~=2.43.0->gdrive) (0.6.1)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests<3.0.0,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (3.4.3)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests<3.0.0,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests<3.0.0,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests<3.0.0,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client~=2.43.0->gdrive) (2025.8.3)
Downloading gdrive-0.1.5-py3-none-any.whl (8.0 kB)
Downloading argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Downloading google_api_python_client-2.43.0-py2.py3-none-any.whl (8.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.3/8.3 MB 97.3 MB/s eta 0:00:00
Downloading google_auth_oauthlib-0.5.3-py2.py3-none-any.whl (19 kB)
Downloading setuptools-59.6.0-py3-none-any.whl (952 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 952.6/952.6 kB 53.7 MB/s eta 0:00:00
Downloading versioneer-0.29-py3-none-any.whl (46 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46.8/46.8 kB 3.3 MB/s eta 0:00:00
Downloading wheel-0.37.1-py2.py3-none-any.whl (35 kB)
Installing collected packages: argparse, wheel, versioneer, setuptools, google-auth-oauthlib, google-api-python-client, gdrive
  Attempting uninstall: wheel
    Found existing installation: wheel 0.45.1
    Uninstalling wheel-0.45.1:
      Successfully uninstalled wheel-0.45.1
  Attempting uninstall: setuptools
    Found existing installation: setuptools 75.2.0
    Uninstalling setuptools-75.2.0:
      Successfully uninstalled setuptools-75.2.0
  Attempting uninstall: google-auth-oauthlib
    Found existing installation: google-auth-oauthlib 1.2.2
    Uninstalling google-auth-oauthlib-1.2.2:
      Successfully uninstalled google-auth-oauthlib-1.2.2
  Attempting uninstall: google-api-python-client
    Found existing installation: google-api-python-client 2.181.0
    Uninstalling google-api-python-client-2.181.0:
      Successfully uninstalled google-api-python-client-2.181.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.
odl 0.8.3 requires setuptools>=65.6, but you have setuptools 59.6.0 which is incompatible.
google-adk 1.13.0 requires google-api-python-client<3.0.0,>=2.157.0, but you have google-api-python-client 2.43.0 which is incompatible.
pandas-gbq 0.29.2 requires google-auth-oauthlib>=0.7.0, but you have google-auth-oauthlib 0.5.3 which is incompatible.
torch 2.8.0+cu126 requires nvidia-cuda-runtime-cu12==12.6.77; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-runtime-cu12 12.5.82 which is incompatible.
torch 2.8.0+cu126 requires nvidia-cufft-cu12==11.3.0.4; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cufft-cu12 11.2.3.61 which is incompatible.
arviz 0.22.0 requires setuptools>=60.0.0, but you have setuptools 59.6.0 which is incompatible.
Successfully installed argparse-1.4.0 gdrive-0.1.5 google-api-python-client-2.43.0 google-auth-oauthlib-0.5.3 setuptools-59.6.0 versioneer-0.29 wheel-0.37.1

Parameters - For Training Hybrid Transformer Model¶

In [ ]:
# This cell is a settings listing, not runnable Python: it mirrors the hyperparameters kept in a .yml file.
# Sync any changes below to the YAML actually loaded later via `opt_path`:
# ./Options/GrayDenoising_Restormer_RetrainCT4.yml

# general settings
name: GrayDenoising_Restormer_RetrainCT4
model_type: ImageCleanModel
scale: 1
num_gpu: 1 # 8  # set num_gpu: 0 for cpu mode
manual_seed: 100

# dataset and data loader settings
datasets:
  train:
    phase: train # added since, missing param for def train dataloader
    name: TrainSet
    type: Dataset_GaussianDenoising
    sigma_type: random
    sigma_range: [0,50]
    in_ch: 1    ## Grayscale image
    dataroot_gt: ./Datasets/train_RetrainCT/gt # ct loc
    dataroot_lq: ./Datasets/train_RetrainCT/lq # ct loc
    geometric_augs: true

    filename_tmpl: '{}'
    io_backend:
      type: disk

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 4 # 8
    batch_size_per_gpu: 4 # 2 # 8

    ### -------------Progressive training--------------------------
    mini_batch_sizes: [4,2,1,1] # [2,1,1,1,1,1] # [8,5,4,2,1,1] # for 40GB VRAM? Batch size per gpu
    iters: [12000,8000,4000,2000] # [92000,64000,48000,36000] # [92000,64000,48000,36000,36000,24000]
    gt_size: 256 # 320 OOM err # 384   # Max patch size for progressive training
    gt_sizes: [128,160,192,256] # [128,160,192,256,320] OOM err # [128,160,192,256,320,384]  # Patch sizes for progressive training.
    ### ------------------------------------------------------------

    ### ------- Training on single fixed-patch size 128x128---------
    # mini_batch_sizes: [4] # [8] for 40GB VRAM?
    # iters: [12000] # [24000]for experimenting # [300000] # original setup
    # gt_size: 128
    # gt_sizes: [128]
    ### ------------------------------------------------------------

    dataset_enlarge_ratio: 1
    prefetch_mode: ~ # cuda # ~
    # pin_memory: true # new experiment

  val:
    phase: val # added since, missing param for def train dataloader
    name: ValSet
    type: Dataset_GaussianDenoising
    sigma_test: 25
    in_ch: 1  ## Grayscale image
    dataroot_gt: ./Datasets/validation_RetrainCT/gt # ./Datasets/test/BSD68
    dataroot_lq: ./Datasets/validation_RetrainCT/lq # none
    io_backend:
      type: disk

# network structures
network_g:
  type: Restormer
  inp_channels: 1
  out_channels: 1
  dim: 48
  num_blocks: [4,6,6,8]
  num_refinement_blocks: 4
  heads: [1,2,4,8]
  ffn_expansion_factor: 2.66
  bias: False
  LayerNorm_type: BiasFree
  dual_pixel_task: False


# path
path:
  pretrain_network_g: ~
  strict_load_g: true
  resume_state: ~

# training settings
train:
  total_iter: 26000 # 26K for progressive 12000 for fixed patch # 24000 # for experimenting single fixed-patch # 240000 for 16GB VRAM progressive # 300000 # Original setup
  warmup_iter: -1 # no warm up
  use_grad_clip: true

  # Split 300k iterations into two cycles.
  # 1st cycle: fixed 3e-4 LR for 92k iters.
  # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
  scheduler:
    type: CosineAnnealingRestartCyclicLR
    periods: [8000, 18000] # [3680, 8320] # [7360, 16640] # for experimenting single patch # [92000, 208000] # original setup
    restart_weights: [1,1]
    eta_mins: [0.0003,0.000001]

  mixing_augs:
    mixup: true
    mixup_beta: 1.2
    use_identity: true

  optim_g:
    type: AdamW
    lr: !!float 3e-4
    weight_decay: !!float 1e-4
    betas: [0.9, 0.999]

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1
    reduction: mean

# validation settings
val:
  window_size: 8
  val_freq: !!float 4e3
  save_img: false # true
  rgb2bgr: true
  use_image: true # false original
  max_minibatch: 4 # 3 # 8

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false

# logging settings
logger:
  print_freq: 1000
  save_checkpoint_freq: !!float 6e3 # 4e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
In [ ]:
# function to bypass argparse since training in Colab

sys.path.append('/home/hiran/Restormer') # local
# sys.path.append('/content/drive/MyDrive/Colab Notebooks/Restormer_1209') # gdrive

import datetime
import logging
import math
import time
# import os # Import the os module
from os import path as osp

from torch.utils import data as data
from torchvision.transforms.functional import normalize
# %cd Denoising/

# direct basicsr if in ./Denoising
# !pip install lmdb
from basicsr.data import create_dataloader, create_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import create_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info,
                           get_root_logger, get_time_str, init_tb_logger,
                           init_wandb_logger, make_exp_dirs, mkdir_and_rename,
                           set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse

from basicsr.data.data_util import (paired_paths_from_folder,paired_DP_paths_from_folder,paired_paths_from_lmdb,paired_paths_from_meta_info_file,paths_from_lmdb)
from basicsr.data.transforms import augment, paired_random_crop, paired_random_crop_DP,random_augmentation
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, padding_DP,imfrombytesDP,scandir
import importlib
# importlib.reload(basicr.utils.scandir) # not a module, error

def parse_(opt_path: str, is_train: bool = True, launcher='none', local_rank=0) -> dict:
    """Load a YAML options file and normalise it into the training `opt` dict.

    Replaces the argparse-based `parse_options` so the notebook can run
    without CLI arguments: reads the YAML via yaml.safe_load, records the
    run flags, configures (or disables) distributed training, and seeds
    the RNGs per rank.
    """
    yaml_file = Path(opt_path)
    if not yaml_file.exists():
        raise FileNotFoundError(f"YAML options file not found: {opt_path}")
    with yaml_file.open('r') as handle:
        opt = yaml.safe_load(handle)
    # the top-level YAML document must be a mapping
    if not isinstance(opt, dict):
        raise ValueError("Parsed YAML is not a mapping (dict).")

    # record how this run was launched
    opt['is_train'] = bool(is_train)
    opt['launcher'] = launcher
    opt['local_rank'] = local_rank

    # distributed settings: launcher 'none' disables dist entirely
    if opt['launcher'] == 'none':
        opt['dist'] = False
    else:
        opt['dist'] = True
        if opt['launcher'] == 'slurm' and 'dist_params' in opt:
            init_dist(opt['launcher'], **opt['dist_params'])
        else:
            init_dist(opt['launcher'])
            print('init dist .. ', opt['launcher'])

    opt['rank'], opt['world_size'] = get_dist_info()

    # per-rank random seed (draw one if the YAML did not pin it)
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    return opt

opt_path = './Options/GrayDenoising_Restormer_RetrainCT4.yml'
# Drop any stale `opt` from an earlier section. A plain `del opt` raises
# NameError on a fresh Restart-&-Run-All where `opt` was never defined;
# pop with a default is a no-op in that case.
globals().pop('opt', None)
opt = parse_(opt_path, is_train=True, launcher='none', local_rank=0 )
opt  # display the parsed options
Out[ ]:
{'name': 'GrayDenoising_Restormer_RetrainCT4',
 'model_type': 'ImageCleanModel',
 'scale': 1,
 'num_gpu': 1,
 'manual_seed': 100,
 'datasets': {'train': {'phase': 'train',
   'name': 'TrainSet',
   'type': 'Dataset_GaussianDenoising',
   'sigma_type': 'random',
   'sigma_range': [0, 50],
   'in_ch': 1,
   'dataroot_gt': './Datasets/train_RetrainCT/gt',
   'dataroot_lq': './Datasets/train_RetrainCT/lq',
   'geometric_augs': True,
   'filename_tmpl': '{}',
   'io_backend': {'type': 'disk'},
   'use_shuffle': True,
   'num_worker_per_gpu': 4,
   'batch_size_per_gpu': 4,
   'mini_batch_sizes': [4, 2, 1, 1],
   'iters': [12000, 8000, 4000, 2000],
   'gt_size': 256,
   'gt_sizes': [128, 160, 192, 256],
   'dataset_enlarge_ratio': 1,
   'prefetch_mode': None},
  'val': {'phase': 'val',
   'name': 'ValSet',
   'type': 'Dataset_GaussianDenoising',
   'sigma_test': 25,
   'in_ch': 1,
   'dataroot_gt': './Datasets/validation_RetrainCT/gt',
   'dataroot_lq': './Datasets/validation_RetrainCT/lq',
   'io_backend': {'type': 'disk'}}},
 'network_g': {'type': 'Restormer',
  'inp_channels': 1,
  'out_channels': 1,
  'dim': 48,
  'num_blocks': [4, 6, 6, 8],
  'num_refinement_blocks': 4,
  'heads': [1, 2, 4, 8],
  'ffn_expansion_factor': 2.66,
  'bias': False,
  'LayerNorm_type': 'BiasFree',
  'dual_pixel_task': False},
 'path': {'pretrain_network_g': None,
  'strict_load_g': True,
  'resume_state': None},
 'train': {'total_iter': 26000,
  'warmup_iter': -1,
  'use_grad_clip': True,
  'scheduler': {'type': 'CosineAnnealingRestartCyclicLR',
   'periods': [8000, 18000],
   'restart_weights': [1, 1],
   'eta_mins': [0.0003, 1e-06]},
  'mixing_augs': {'mixup': True, 'mixup_beta': 1.2, 'use_identity': True},
  'optim_g': {'type': 'AdamW',
   'lr': 0.0003,
   'weight_decay': 0.0001,
   'betas': [0.9, 0.999]},
  'pixel_opt': {'type': 'L1Loss', 'loss_weight': 1, 'reduction': 'mean'}},
 'val': {'window_size': 8,
  'val_freq': 4000.0,
  'save_img': False,
  'rgb2bgr': True,
  'use_image': True,
  'max_minibatch': 4,
  'metrics': {'psnr': {'type': 'calculate_psnr',
    'crop_border': 0,
    'test_y_channel': False}}},
 'logger': {'print_freq': 1000,
  'save_checkpoint_freq': 6000.0,
  'use_tb_logger': True,
  'wandb': {'project': None, 'resume_id': None}},
 'dist_params': {'backend': 'nccl', 'port': 29500},
 'is_train': True,
 'launcher': 'none',
 'local_rank': 0,
 'dist': False,
 'rank': 0,
 'world_size': 1}
In [ ]:
# some new configs specific to below setup which we didn't place inside YML

# Add the missing 'experiments_root' key to the 'path' dictionary
if 'experiments_root' not in opt['path']:
  opt['path']['experiments_root'] = f'./experiments/{opt["name"]}'

# Add the missing 'log' key to the 'path' dictionary
if 'log' not in opt['path']:
  opt['path']['log'] = f"{opt['path']['experiments_root']}/logs"

# Add the missing 'models' key to the 'path' dictionary
if 'models' not in opt['path']:
   opt['path']['models'] = f"./model_zoo/{opt['name']}"

# Add the missing 'training_states' key to the 'path' dictionary
if 'training_states' not in opt['path']:
   opt['path']['training_states'] = f"{opt['path']['experiments_root']}/training_states"

if 'visualization' not in opt['path']:
    opt['path']['visualization'] = f"{opt['path']['experiments_root']}/visuals"

# Add the missing augmentation scaling factor parameter specified to 'train & val dataset'
opt['datasets']['train']['scale'] = opt['scale']
opt['datasets']['val']['scale'] = opt['scale']

Training the Novel Hybrid Network¶

In [ ]:
def init_loggers(opt):
    """Create the root text logger plus optional wandb/tensorboard loggers.

    Returns (logger, tb_logger); tb_logger is None when tensorboard logging
    is disabled or the run name contains 'debug'.
    """
    log_path = osp.join(opt['path']['log'],
                        f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr',
                             log_level=logging.INFO,
                             log_file=log_path)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # wandb must be initialized before tensorboard to allow proper sync
    wandb_opt = opt['logger'].get('wandb')
    is_debug = 'debug' in opt['name']
    if wandb_opt is not None and wandb_opt.get('project') is not None and not is_debug:
        assert opt['logger'].get('use_tb_logger') is True, (
            'should turn on tensorboard when using wandb')
        init_wandb_logger(opt)

    tb_logger = None
    if opt['logger'].get('use_tb_logger') and not is_debug:
        tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
    return logger, tb_logger

# Dataset_GaussianDenoising
def create_train_val_dataloader(opt, logger):
    """Build the train and val dataloaders described in opt['datasets'].

    Returns (train_loader, train_sampler, val_loader, total_epochs,
    total_iters).  Fix: the sampler/epoch/iter variables are pre-initialized
    so a config without a 'train' phase no longer raises NameError at the
    return statement.
    """
    train_loader, val_loader = None, None
    train_sampler = None
    total_epochs, total_iters = 0, 0
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)

            # generalisation tool - scale automatically alongside flipping, padding (bool)
            dataset_opt['scale'] = opt['scale']

            train_set = create_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'],
                                            opt['rank'], dataset_enlarge_ratio)
            train_loader = create_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'])

            # iterations needed to see the (enlarged) dataset once
            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio /
                (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info(
                'Training statistics:'
                f'\n\tNumber of train images: {len(train_set)}'
                f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                f'\n\tWorld size (gpu number): {opt["world_size"]}'
                f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')

        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(
                val_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=None,
                seed=opt['manual_seed'])
            logger.info(
                f'Number of val images/folders in {dataset_opt["name"]}: '
                f'{len(val_set)}')
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loader, total_epochs, total_iters


# def main(): # indent again when complete

# parse options, set distributed setting, set ramdom seed
# opt = parse_options(is_train=True) # opt dict created earlier in this section

torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True

# automatic resume: pick the newest '<iter>.state' checkpoint, if any exist
# NOTE(review): this appends opt['name'] to 'training_states', which itself
# already lives under experiments_root/<name> — confirm checkpoints are really
# written to that doubled path (the original layout was
# 'experiments/<name>/training_states/').
state_folder_path = '{}/{}/'.format(opt['path']['training_states'],opt['name'])

# import os
try:
    states = os.listdir(state_folder_path)
except OSError:  # fix: bare 'except' also swallowed KeyboardInterrupt/SystemExit
    states = []

resume_state = None
if len(states) > 0:
    # filenames look like '<iter>.state'; strip the 6-char '.state' suffix
    max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
    resume_state = os.path.join(state_folder_path, max_state_file)
    opt['path']['resume_state'] = resume_state

# load resume states if necessary (map tensors onto the current CUDA device)
if opt['path'].get('resume_state'):
    device_id = torch.cuda.current_device()
    resume_state = torch.load(
        opt['path']['resume_state'],
        map_location=lambda storage, loc: storage.cuda(device_id))
else:
    resume_state = None

if resume_state is None:
    # fresh run: create the experiment directory tree before logging starts
    # (the commented block below was a one-off workaround for pre-existing dirs)
    # # Modified mkdir_and_rename to handle non-existent directories
    # def mkdir_and_rename_modified(path):
    #     if os.path.exists(path):
    #         new_name = path + '_archived_' + get_time_str()
    #         print(f'Path already exists. Rename it to {new_name}', flush=True)
    #         os.rename(path, new_name)
    #     os.makedirs(path, exist_ok=True)

    # # Replace the original mkdir_and_rename with the modified one
    # original_mkdir_and_rename = mkdir_and_rename
    # mkdir_and_rename = mkdir_and_rename_modified

    make_exp_dirs(opt)

    # # Restore the original mkdir_and_rename
    # mkdir_and_rename = original_mkdir_and_rename

    # only rank 0 owns the tensorboard directory; old runs get archived/renamed
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
            'name'] and opt['rank'] == 0:
        mkdir_and_rename(osp.join('tb_logger', opt['name']))

# initialize loggers
logger, tb_logger = init_loggers(opt)

# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger) # if scandir error pls. import from above
train_loader, train_sampler, val_loader, total_epochs, total_iters = result

# create model
if resume_state:  # resume training
    check_resume(opt, resume_state['iter'])
    model = create_model(opt)
    model.resume_training(resume_state)  # handle optimizers and schedulers
    logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
                f"iter: {resume_state['iter']}.")
    start_epoch = resume_state['epoch']
    current_iter = resume_state['iter']
else:
    # fresh run: start counting epochs/iterations from zero
    model = create_model(opt)
    start_epoch = 0
    current_iter = 0

# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger)

# dataloader prefetcher: 'cpu' (default) or 'cuda' (requires pin_memory)
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu':
    prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda':
    prefetcher = CUDAPrefetcher(train_loader, opt)
    logger.info(f'Use {prefetch_mode} prefetch dataloader')
    if opt['datasets']['train'].get('pin_memory') is not True:
        raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
else:
    raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.'
                      "Supported ones are: None, 'cuda', 'cpu'.")

# training
logger.info(
    f'Start training from epoch: {start_epoch}, iter: {current_iter}')
data_time = time.time()
iter_time = time.time()
start_time = time.time()

# for epoch in range(start_epoch, total_epochs + 1):

# Progressive-learning schedule, read from the train-dataset config.
train_opt = opt['datasets']['train']
iters = train_opt.get('iters')
batch_size = train_opt.get('batch_size_per_gpu')
mini_batch_sizes = train_opt.get('mini_batch_sizes')
gt_size = train_opt.get('gt_size')
mini_gt_sizes = train_opt.get('gt_sizes')

# cumulative iteration boundary of each progressive stage
groups = np.cumsum(iters)

# one-shot "stage changed" log flags, one per stage
logger_j = [True] * len(groups)

scale = opt['scale']

epoch = start_epoch
while current_iter <= total_iters:
    # reshuffle the enlarged sampler for this epoch, then restart prefetching
    train_sampler.set_epoch(epoch)
    prefetcher.reset()
    train_data = prefetcher.next()

    while train_data is not None:
        data_time = time.time() - data_time  # time spent waiting on data

        current_iter += 1
        if current_iter > total_iters:
            break
        # update learning rate
        model.update_learning_rate(
            current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))


        ### ------Progressive learning ---------------------
        # first stage whose cumulative boundary has not been passed yet
        j = ((current_iter>groups) !=True).nonzero()[0]
        if len(j) == 0:
            bs_j = len(groups) - 1  # past all boundaries: stay in last stage
        else:
            bs_j = j[0]

        # patch size and batch size for the current stage
        mini_gt_size = mini_gt_sizes[bs_j]
        mini_batch_size = mini_batch_sizes[bs_j]

        if logger_j[bs_j]:
            # announce each stage switch exactly once
            logger.info('\n Updating Patch_Size to {} and Batch_Size to {} \n'.format(mini_gt_size, mini_batch_size*torch.cuda.device_count()))
            logger_j[bs_j] = False

        lq = train_data['lq']
        gt = train_data['gt']

        # subsample the loaded batch down to the stage's mini batch size
        if mini_batch_size < batch_size:
            indices = random.sample(range(0, batch_size), k=mini_batch_size)
            lq = lq[indices]
            gt = gt[indices]

        # random-crop the loaded patches down to the stage's patch size;
        # the gt crop is scaled to keep lq/gt spatially aligned
        if mini_gt_size < gt_size:
            x0 = int((gt_size - mini_gt_size) * random.random())
            y0 = int((gt_size - mini_gt_size) * random.random())
            x1 = x0 + mini_gt_size
            y1 = y0 + mini_gt_size
            lq = lq[:,:,x0:x1,y0:y1]
            gt = gt[:,:,x0*scale:x1*scale,y0*scale:y1*scale]
        ###-------------------------------------------


        model.feed_train_data({'lq': lq, 'gt':gt})
        model.optimize_parameters(current_iter)

        iter_time = time.time() - iter_time  # duration of this iteration
        # log
        if current_iter % opt['logger']['print_freq'] == 0:
            log_vars = {'epoch': epoch, 'iter': current_iter}
            log_vars.update({'lrs': model.get_current_learning_rate()})
            log_vars.update({'time': iter_time, 'data_time': data_time})
            log_vars.update(model.get_current_log())
            msg_logger(log_vars)

        # save models and training states
        if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
            logger.info('Saving models and training states.')
            model.save(epoch, current_iter)

        # validation
        if opt.get('val') is not None and (current_iter %
                                            opt['val']['val_freq'] == 0):
            rgb2bgr = opt['val'].get('rgb2bgr', True)
            # wheather use uint8 image to compute metrics
            use_image = opt['val'].get('use_image', True)
            model.validation(val_loader, current_iter, tb_logger,
                              opt['val']['save_img'], rgb2bgr, use_image )

        # restart the per-iteration timers, then fetch the next batch
        data_time = time.time()
        iter_time = time.time()
        train_data = prefetcher.next()
    # end of iter
    epoch += 1

# end of epoch

consumed_time = str(
    datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}')
logger.info('Save the latest model.')
model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
if opt.get('val') is not None:
    # NOTE(review): unlike the in-loop validation, rgb2bgr/use_image are left
    # at their defaults here — confirm that matches the intended final eval
    model.validation(val_loader, current_iter, tb_logger,
                      opt['val']['save_img'])
if tb_logger:
    tb_logger.close()


# if __name__ == '__main__':
#     main()

Final Hybrid Model with Transformer NN/ CT denoising on top of FBP Model¶

Apply Complete Hybrid Blind Gaussian Restormer on top of the FBP model¶

In [ ]:
def eval_ctrecn(test_data_, hybrid_model):
  """Run FBP followed by the given denoising model over a whole dataset.

  Returns two lists covering every (obs, gt) pair in ``test_data_``:
  the reconstructions and their PSNR values against the ground truth.
  """
  recos_ = []
  psnrs_ = []
  multiple = 8  # the denoiser needs H and W divisible by 8
  # classical FBP stage (same settings as the ctexample section)
  reconstructor = FBPReconstructor(ray_trafo, hyper_params={
      'filter_type': 'Hann',
      'frequency_scaling': 0.8})

  with torch.no_grad():  # inference only: skip gradient bookkeeping
    for obs, gt in test_data_:
        # release cached / inter-process CUDA memory between samples
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

        # FBP -> odl element (H, W) with normalized pixel values
        fbp_reco = reconstructor.reconstruct(obs)
        # transition layer: odl element -> np array (H, W, 1)
        fbp_img = odl_to_single(fbp_reco)
        # -> tensor (1, C, H, W) on the GPU
        input_ = torch.from_numpy(fbp_img).permute(2, 0, 1).unsqueeze(0).cuda()

        # reflect-pad height/width up to the next multiple of 8
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % multiple
        padw = (-w) % multiple
        input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

        # denoise; output stays (1, C, H', W') with pixels in [0, 1]
        restored_ = hybrid_model(input_)
        restored_ = torch.clamp(restored_, 0, 1)
        # drop batch/channel dims and the padding before evaluation
        restored_ = restored_[0, 0, :h, :w]
        reco2 = restored_.cpu().detach().numpy()
        recos_.append(reco2)
        psnrs_.append(PSNR(reco2, gt))

  print('whole eval mean psnr: {:f}'.format(np.mean(psnrs_)))
  return recos_, psnrs_

# try:
#   del recos_
#   del psnrs_
# except:
#   print('no recos2')
# finally:
#   recos_ = []
#   psnrs_ = []

Saving as a Class & Compile the 2nd Hybrid Neural Network¶

In [24]:
# ensure we are inside the Restormer Denoising tree before using its
# relative paths (IPython %cd magic; hardcoded to this machine's layout)
if os.getcwd() != '/home/hiran/Restormer/Denoising':
  %cd '/home/hiran/Restormer/Denoising'
/home/hiran/Restormer/Denoising
In [25]:
# !ls
# os.getcwd() # confirm cwd is Denoiser
# sys.path.insert(0,os.getcwd())
# utils.load_gray_img? # confirm functions are imported
# sys.modules.pop("utils",None) # if incorrect utils is loaded
# sys.path.append('basicsr')

# Compile Model

def hybrid_model(weights, yaml_file):
  """Build a Restormer from a YAML config and load pretrained weights.

  Args:
    weights: path to a checkpoint whose 'params' entry holds the state dict.
    yaml_file: path to the training YAML; its 'network_g' section carries
      the constructor arguments.

  Returns:
    The model moved to the GPU and switched to eval mode.
  """
  try:
      from yaml import CLoader as Loader
  except ImportError:
      from yaml import Loader  # pure-python fallback when libyaml is absent

  # fix: close the config file deterministically instead of leaking the handle
  with open(yaml_file, mode='r') as f:
      x = yaml.load(f, Loader=Loader)
  # 'type' names the architecture for basicsr; Restormer's __init__ does not
  # accept it, so drop it before unpacking the kwargs
  x['network_g'].pop('type')
  model_restoration = Restormer(**x['network_g'])

  # NOTE(review): torch.load without map_location assumes the checkpoint's
  # original device is available — confirm if CPU-only loading is needed
  checkpoint = torch.load(weights)
  model_restoration.load_state_dict(checkpoint['params'])
  model_restoration.cuda()
  # model_restoration = nn.DataParallel(model_restoration) # this changes model type to DataParallel
  model_restoration.eval()
  return model_restoration

##############

class hybrid_model_cls(Restormer):
    """FBP + Restormer hybrid reconstructor.

    Wraps a pretrained Restormer denoiser behind a dival-style
    ``reconstruct(obs)`` interface so it can sit in the same evaluation
    task table as the classical reconstructors.
    """

    def __init__(self, weights, yaml_file, name='not_defined'):
        super(hybrid_model_cls, self).__init__()

        self.name = name          # display name used in result tables
        self.weights = weights    # checkpoint path ('params' holds the state dict)
        self.yaml_file = yaml_file  # training YAML with the 'network_g' kwargs

        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader  # pure-python fallback when libyaml is absent

        # fix: close the YAML file deterministically instead of leaking it
        with open(self.yaml_file, mode='r') as f:
            x = yaml.load(f, Loader=Loader)
        # Restormer's __init__ has no 'type' parameter; remove it first
        x['network_g'].pop('type')
        model_restoration = Restormer(**x['network_g'])
        checkpoint = torch.load(self.weights)
        model_restoration.load_state_dict(checkpoint['params'])
        model_restoration.cuda()
        model_restoration = nn.DataParallel(model_restoration)
        self.model = model_restoration.eval()

        ########## building the pipeline

    # NOTE(review): the default fbp_model is bound at class-definition time to
    # whatever `reconstructor_lodopab` is then — confirm that is intended.
    def reconstruct(self, obs, fbp_model = reconstructor_lodopab): # define _reconstruct if wwant to default to Restormer reconstruct attrib
      self.obs = obs
      self.fbp_model = fbp_model
      img_multiple_of = 8  # the denoiser needs H and W divisible by 8

      with torch.no_grad(): # save memory by not calc gradient
        torch.cuda.ipc_collect() # collects unnecessary inter-process comm.s and free VRAM
        torch.cuda.empty_cache()

        # FBP stage: odl element (H, W) with normalized pixel values
        reco = self.fbp_model.reconstruct(self.obs)
        # transition layer: -> np array (H, W, 1) -> tensor (1, C, H, W) on GPU
        reco = odl_to_single(reco)
        input_ = torch.from_numpy(reco).permute(2,0,1).unsqueeze(0).cuda()

        # reflect-pad height/width up to the next multiple of 8
        h,w = input_.shape[2], input_.shape[3]
        H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of
        padh = H-h if h%img_multiple_of!=0 else 0
        padw = W-w if w%img_multiple_of!=0 else 0
        input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

        # denoise; output (1, C, H', W') with pixels clamped to [0, 1]
        restored_ = self.model(input_)
        restored_ = torch.clamp(restored_, 0, 1)
        # drop batch/channel dims and the padding before evaluation
        restored_ = restored_[0,0,:h,:w]
        reco2 = restored_.cpu().detach()

        # wrap as a discretized-space element for task-table compliance
        self.reco = uniform_discr_element(reco2)
      return self.reco

# retrained checkpoint + matching training config for the first hybrid model
weights1 = './model_zoo/GrayDenoising_Restormer_RetrainCT/net_g_24000.pth'
yaml_file1 = 'Options/GrayDenoising_Restormer_RetrainCT.yml'

hybrid_model_t1 = hybrid_model_cls(weights1, yaml_file1, name= "Hybrid Transformer Model1")
In [ ]:
# testing pipeline

# NOTE(review): if 'recos2' exists but 'psnrs2' does not, the second del
# raises NameError; both names are re-bound to fresh lists right below anyway
if 'recos2' in locals():
  del recos2
  del psnrs2
recos2 = []
psnrs2 = []

for obs, gt in test_data_10:
    torch.cuda.ipc_collect() # collects unnecessary inter-process comm.s and free VRAM
    torch.cuda.empty_cache()
    # reco = reconstructor.reconstruct(obs) # return odl elem H,W with normalized pixel vals
    # hybrid_model_1.reconstruct(reco)

    # locate new model in class directly to projection
    reco2 = hybrid_model_t1.reconstruct(obs)

    # resume default pipeline
    recos2.append(reco2)
    psnrs2.append(PSNR(reco2, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs2)))

# side-by-side plots of the first three reconstructions vs ground truth
for i in range(3):
    _, ax = plot_images([recos2[i], test_data_10.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs2[i]))
    ax[0].set_title('Hybrid_Transformer_Reconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))

# del hybrid_model_cls
mean psnr: 30.579241
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Evaluation¶

In [26]:
# Instantiate every reconstructor variant used in the evaluation table.
# All UNetRes variants share the same checkpoint and differ only in the
# assumed noise sigma.
# NOTE(review): this block is five near-identical stanzas — a loop over
# sigmas would avoid the copy-paste, but the individual names are referenced
# by later cells, so they are kept as-is here.
weightsU8 = '/home/hiran/model_zoo/drunet_gray.pth'
nameU8 = 'Hybrid UNet Residual model sigma 8'
hybrid_model_u8 = hybrid_model_UNetRes(weightsU8, nameU8, sigma = 8)

weightsU15 = '/home/hiran/model_zoo/drunet_gray.pth'
nameU15 = 'Hybrid UNet Residual model sigma 15'
hybrid_model_u15 = hybrid_model_UNetRes(weightsU15, nameU15, sigma = 15)

weightsU25 = '/home/hiran/model_zoo/drunet_gray.pth'
nameU25 = 'Hybrid UNet Residual model sigma 25'
hybrid_model_u25 = hybrid_model_UNetRes(weightsU25, nameU25, sigma = 25)

weightsU35 = '/home/hiran/model_zoo/drunet_gray.pth'
nameU35 = 'Hybrid UNet Residual model sigma 35'
hybrid_model_u35 = hybrid_model_UNetRes(weightsU35, nameU35, sigma = 35)

weightsU42 = '/home/hiran/model_zoo/drunet_gray.pth'
nameU42 = 'Hybrid UNet Residual model sigma 42'
hybrid_model_u42 = hybrid_model_UNetRes(weightsU42, nameU42, sigma = 42)

# Restormer-based hybrids: different checkpoints / training regimes.
weights0 = '/home/hiran/Restormer/Denoising/model_zoo/net_g_240000.pth'
yaml_file0 = 'Options/GaussianGrayDenoising_Restormer.yml'

weights1 = '/home/hiran/Restormer/Denoising/model_zoo/GrayDenoising_Restormer_RetrainCT/net_g_24000.pth'
yaml_file1 = 'Options/GrayDenoising_Restormer_RetrainCT.yml'

weights4 = '/home/hiran/Restormer/Denoising/model_zoo/GrayDenoising_Restormer_RetrainCT4/net_g_24000.pth'
yaml_file4 = 'Options/GrayDenoising_Restormer_RetrainCT4.yml'

weights5 = '/home/hiran/Restormer/Denoising/pretrained_models/gaussian_gray_denoising_blind.pth'
yaml_file5 = 'Options/GaussianGrayDenoising_Restormer - original backup.yml'

# del weights, yaml_file

hybrid_model_t0 = hybrid_model_cls(weights0, yaml_file0, name= "Hbrd Transf0-CT Progr Training 300K Itr")
hybrid_model_t1 = hybrid_model_cls(weights1, yaml_file1, name= "Hbrd Transf1-CT retrn fixPatch 240K Itr")
hybrid_model_t4 = hybrid_model_cls(weights4, yaml_file4, name= "Hbrd Transf4-CT Progr Training 24K Itr")
hybrid_model_t5 = hybrid_model_cls(weights5, yaml_file5, name= "Hbrd Transf5-Pre_Trained_blind")

# recos2, psnrs_ = eval_ctrecn (test_data_2, hybrid_model_1) # psnr 30.221161 for 2, 30.108349 for 100
# plot_ctrecn(test_data_2, recos2, psnrs_, visuals = 2)
In [ ]:
# test_set_2
# psnr 29.048862 retrainCT4 12000 progressive 24K itr, but 12 is the best
# psnr 28.960581 retrainCT3 12000 fix patch quick
# psnr 30.221161 retrainCT 24000 fix patch
# best psnr seen - 31.39...

# test_set_10
# last psnr - 31.395075 retrainCT4 12000
# 32 or 31.251339 retrainCT3
# 31.046551,
# best psnr seen - 31.39...

All model Evaluation¶

In [ ]:
# experimenting to get the models perform on test_data, test_data_2

np.random.seed(0) # to make pois noise in obs is consistent throughout reconstructors; fair evaluation hence.

# obs0, gt0 = test_data [0]
# test_pair_0 = DataPairs(obs0, gt0, name='test_pair_0')

# lodopab, ellipses doesn't need to know reco space since ray_trafo obtainable .get_ra_trafo
# shepp-logan phantom
reco_space = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300],dtype='float32')

#######  Ray transformations, uncomment according to dataset tested
# radon transform function ( build sinogram from a ct scan )

ray_trafo = dataset.get_ray_trafo(impl=IMPL) # - for Lodopab dataset
# ray_trafo = dataset_ellipses.get_ray_trafo(impl=IMPL) # - for ellipsis dataset
# ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='astra_cuda') # shepp-logan

# phantom = odl.phantom.shepp_logan(reco_space, modified=True) # import standard scientific sample ct named shpp logan. that's our gt
# ground_truth = phantom

# geometry = odl.tomo.cone_beam_geometry(reco_space, 40, 40, 360) # build our ct machine geometry using odl # object = human cross section space, source = ray emmiter radius from origin = human , likewise radius from origin to detect, optional no.of angles in our geometry

# proj_data = ray_trafo(phantom) # call the function to build the sinogram aka projection
# observation = (proj_data + np.random.poisson(0.3, proj_data.shape)).asarray()
# test_data_shepp = DataPairs(observation, ground_truth, name='shepp-logan + pois')

# %% task table and reconstructors
# classical baselines, each seeded with a zero (or constant) start image and
# a fixed, small iteration budget
eval_tt = TaskTable()
fbp_reconstructor = FBPReconstructor(ray_trafo)
cg_reconstructor = CGReconstructor(ray_trafo, ray_trafo.domain.zero(), 4)
gn_reconstructor = GaussNewtonReconstructor(ray_trafo, ray_trafo.domain.zero(), 2)
lw_reconstructor = LandweberReconstructor(ray_trafo, ray_trafo.domain.zero(), 8)
mlem_reconstructor = MLEMReconstructor(ray_trafo, 0.5*ray_trafo.domain.one(), 1)
ista_reconstructor = ISTAReconstructor(ray_trafo,ray_trafo.domain.zero(), 10) # works
pdhg_reconstructor = PDHGReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # operand issue
dougrach_reconstructor = DouglasRachfordReconstructor(ray_trafo,
                                                      ray_trafo.domain.zero(), 10) # operand issue
forwardbackward_reconstructor = ForwardBackwardReconstructor(ray_trafo,
                                                      ray_trafo.domain.zero(), 10) # operand issue
admm_reconstructor = ADMMReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # works
bfgs_reconstructor = BFGSReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # works

#,   pdhg_reconstructor, dougrach_reconstructor ,forwardbackward_reconstructor ]
                  # removed at 6-8 due to unsupported operand type(s) for +: 'MultiplyOperator' and 'DiscretizedSpaceElement'

# NOTE(review): iradonmap_reconstructor is defined in an earlier cell; this
# list mixes classical and hybrid reconstructors into one task table
reconstructors_tested = [fbp_reconstructor, admm_reconstructor, ista_reconstructor, mlem_reconstructor, gn_reconstructor, hybrid_model_u8, lw_reconstructor, hybrid_model_t0, hybrid_model_u15, hybrid_model_u25, hybrid_model_u35, hybrid_model_t4, cg_reconstructor, iradonmap_reconstructor, hybrid_model_t1, hybrid_model_t5] # , learnedpd_reconstructor, diptv_reconstructor - takes very long bfgs_reconstructor - negative PSNR

options = {'save_iterates': False, 'skip_training': True} # True original # addition skip_training': True

# eval_tt.append_all_combinations(reconstructors=reconstructors_tested, test_data= [test_data_ellipses], options=options)
eval_tt.append_all_combinations(reconstructors=reconstructors_tested, test_data=[test_data_10], options=options) # original

# testing one reconstructor
# eval_tt.append_all_combinations(reconstructors=[dougrach_reconstructor],
#                                 test_data=[test_data_shepp], options=options)

# %% run task table
results = eval_tt.run()
results.apply_measures([PSNR, SSIM])
# print(results.to_string)
print(results)

# # %% plot reconstructions
# fig = results.plot_all_reconstructions(fig_size=(9, 4), vrange='individual')

# # %% plot convergence of CG # comment out if testing one reconstructor
# results.plot_convergence(1, fig_size=(9, 6), gridspec_kw={'hspace': 0.5})

# # %% plot performance
# results.plot_performance(PSNR, figsize=(10, 4))
running task 0/16 ...
running task 1/16 ...
running task 2/16 ...
running task 3/16 ...
running task 4/16 ...
running task 5/16 ...
running task 6/16 ...
running task 7/16 ...
running task 8/16 ...
running task 9/16 ...
running task 10/16 ...
running task 11/16 ...
running task 12/16 ...
running task 13/16 ...
running task 14/16 ...
running task 15/16 ...
ResultTable(results=
                                                 reconstructor       test_data                     measure_values
task_ind sub_task_ind                                                                                            
0        0                                    FBPReconstructor  test part 0:10   mean: {psnr: 25.4, ssim: 0.4552}
1        0                                   ADMMReconstructor  test part 0:10  mean: {psnr: 11.6, ssim: 0.07373}
2        0                                   ISTAReconstructor  test part 0:10  mean: {psnr: 11.6, ssim: 0.07372}
3        0                                   MLEMReconstructor  test part 0:10  mean: {psnr: 17.93, ssim: 0.4488}
4        0                            GaussNewtonReconstructor  test part 0:10  mean: {psnr: 18.47, ssim: 0.4831}
5        0                  Hybrid UNet Residual model sigma 8  test part 0:10    mean: {psnr: 22.99, ssim: 0.48}
6        0                              LandweberReconstructor  test part 0:10  mean: {psnr: 23.01, ssim: 0.5603}
7        0             Hbrd Transf0-CT Progr Training 300K Itr  test part 0:10  mean: {psnr: 25.89, ssim: 0.4755}
8        0                 Hybrid UNet Residual model sigma 15  test part 0:10   mean: {psnr: 27.63, ssim: 0.618}
9        0                 Hybrid UNet Residual model sigma 25  test part 0:10  mean: {psnr: 29.35, ssim: 0.6853}
10       0                 Hybrid UNet Residual model sigma 35  test part 0:10   mean: {psnr: 30.14, ssim: 0.722}
11       0              Hbrd Transf4-CT Progr Training 24K Itr  test part 0:10  mean: {psnr: 27.24, ssim: 0.5433}
12       0                                     CGReconstructor  test part 0:10  mean: {psnr: 26.43, ssim: 0.6441}
13       0                              IRadonMapReconstructor  test part 0:10   mean: {psnr: 30.4, ssim: 0.7293}
14       0             Hbrd Transf1-CT retrn fixPatch 240K Itr  test part 0:10  mean: {psnr: 30.58, ssim: 0.7132}
15       0                      Hbrd Transf5-Pre_Trained_blind  test part 0:10  mean: {psnr: 25.89, ssim: 0.4772}
)
In [ ]:
# Same evaluation as the previous cell, re-run on the 50-sample split.
# NOTE(review): this rebuilds every reconstructor from scratch — the setup is
# copy-pasted across three cells; a shared helper would avoid the drift risk.
ray_trafo = dataset.get_ray_trafo(impl=IMPL) # - for Lodopab dataset

eval_tt = TaskTable()
fbp_reconstructor = FBPReconstructor(ray_trafo)
cg_reconstructor = CGReconstructor(ray_trafo, ray_trafo.domain.zero(), 4)
gn_reconstructor = GaussNewtonReconstructor(ray_trafo, ray_trafo.domain.zero(), 2)
lw_reconstructor = LandweberReconstructor(ray_trafo, ray_trafo.domain.zero(), 8)
mlem_reconstructor = MLEMReconstructor(ray_trafo, 0.5*ray_trafo.domain.one(), 1)
ista_reconstructor = ISTAReconstructor(ray_trafo,ray_trafo.domain.zero(), 10) # works
pdhg_reconstructor = PDHGReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # operand issue
dougrach_reconstructor = DouglasRachfordReconstructor(ray_trafo,
                                                      ray_trafo.domain.zero(), 10) # operand issue
forwardbackward_reconstructor = ForwardBackwardReconstructor(ray_trafo,
                                                      ray_trafo.domain.zero(), 10) # operand issue
admm_reconstructor = ADMMReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # works
bfgs_reconstructor = BFGSReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # works


reconstructors_tested = [fbp_reconstructor, admm_reconstructor, ista_reconstructor, mlem_reconstructor, gn_reconstructor, hybrid_model_u8, lw_reconstructor, hybrid_model_t0, hybrid_model_u15, hybrid_model_u25, hybrid_model_u35, hybrid_model_t4, cg_reconstructor, iradonmap_reconstructor, hybrid_model_t1, hybrid_model_t5]

options = {'save_iterates': False, 'skip_training': True}
eval_tt.append_all_combinations(reconstructors=reconstructors_tested, test_data=[test_data_50], options=options)
results = eval_tt.run()
results.apply_measures([PSNR, SSIM])
print(results)
ResultTable(results=
                                                 reconstructor       test_data                     measure_values
task_ind sub_task_ind                                                                                            
0        0                                    FBPReconstructor  test part 0:50  mean: {psnr: 23.87, ssim: 0.4011}
1        0                                   ADMMReconstructor  test part 0:50  mean: {psnr: 10.2, ssim: 0.07581}
2        0                                   ISTAReconstructor  test part 0:50  mean: {psnr: 10.2, ssim: 0.07581}
3        0                                   MLEMReconstructor  test part 0:50  mean: {psnr: 17.09, ssim: 0.4496}
4        0                            GaussNewtonReconstructor  test part 0:50  mean: {psnr: 17.42, ssim: 0.4828}
5        0                  Hybrid UNet Residual model sigma 8  test part 0:50   mean: {psnr: 20.8, ssim: 0.4176}
6        0                              LandweberReconstructor  test part 0:50  mean: {psnr: 22.46, ssim: 0.5637}
7        0             Hbrd Transf0-CT Progr Training 300K Itr  test part 0:50   mean: {psnr: 24.31, ssim: 0.418}
8        0                 Hybrid UNet Residual model sigma 15  test part 0:50  mean: {psnr: 25.85, ssim: 0.5704}
9        0                 Hybrid UNet Residual model sigma 25  test part 0:50  mean: {psnr: 28.74, ssim: 0.6728}
10       0                 Hybrid UNet Residual model sigma 35  test part 0:50  mean: {psnr: 29.22, ssim: 0.7016}
11       0              Hbrd Transf4-CT Progr Training 24K Itr  test part 0:50  mean: {psnr: 25.75, ssim: 0.4853}
12       0                                     CGReconstructor  test part 0:50  mean: {psnr: 25.94, ssim: 0.6378}
13       0                              IRadonMapReconstructor  test part 0:50  mean: {psnr: 29.76, ssim: 0.7215}
14       0             Hbrd Transf1-CT retrn fixPatch 240K Itr  test part 0:50  mean: {psnr: 29.77, ssim: 0.6974}
)
In [ ]:
# Build the forward operator and one instance of each classical dival
# reconstructor, then evaluate the selected subset on the full test split.
# NOTE(review): the third positional argument of the iterative reconstructors
# is presumably the iteration count — confirm in the dival/ODL reconstructor docs.
ray_trafo = dataset.get_ray_trafo(impl=IMPL) # - for Lodopab dataset

eval_tt = TaskTable()
fbp_reconstructor = FBPReconstructor(ray_trafo)
cg_reconstructor = CGReconstructor(ray_trafo, ray_trafo.domain.zero(), 4)
gn_reconstructor = GaussNewtonReconstructor(ray_trafo, ray_trafo.domain.zero(), 2)
lw_reconstructor = LandweberReconstructor(ray_trafo, ray_trafo.domain.zero(), 8)
# 0.5*one() gives a strictly positive starting image — MLEM's multiplicative
# updates presumably require positivity; a zero start would stall. TODO confirm.
mlem_reconstructor = MLEMReconstructor(ray_trafo, 0.5*ray_trafo.domain.one(), 1)
ista_reconstructor = ISTAReconstructor(ray_trafo,ray_trafo.domain.zero(), 10) # works
# The three reconstructors below raise an operand/type error with this setup
# and are therefore excluded from reconstructors_tested.
pdhg_reconstructor = PDHGReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # operand issue
dougrach_reconstructor = DouglasRachfordReconstructor(ray_trafo,
                                                      ray_trafo.domain.zero(), 10) # operand issue
forwardbackward_reconstructor = ForwardBackwardReconstructor(ray_trafo,
                                                      ray_trafo.domain.zero(), 10) # operand issue
admm_reconstructor = ADMMReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # works
bfgs_reconstructor = BFGSReconstructor(ray_trafo, ray_trafo.domain.zero(), 10) # works


# Same reconstructor subset as the 50-sample run above, now on test_data_all.
reconstructors_tested = [fbp_reconstructor, admm_reconstructor, ista_reconstructor, mlem_reconstructor, gn_reconstructor, hybrid_model_u8, lw_reconstructor, hybrid_model_t0, hybrid_model_u15, hybrid_model_u25, hybrid_model_u35, hybrid_model_t4, cg_reconstructor, iradonmap_reconstructor, hybrid_model_t1, hybrid_model_t5]

options = {'save_iterates': False, 'skip_training': True}
eval_tt.append_all_combinations(reconstructors=reconstructors_tested, test_data=[test_data_all], options=options)
results = eval_tt.run()
results.apply_measures([PSNR, SSIM])
print(results)
running task 0/15 ...
running task 1/15 ...
running task 2/15 ...
In [27]:
# Learned reconstructors to run alongside the published dival baselines:
# the IRadonMap model, the transformer hybrids (t0/t1/t4/t5), and the
# UNet residual denoisers trained at sigma 8/15/25/35/42.
oth_recons = [
    iradonmap_reconstructor,
    hybrid_model_t0,
    hybrid_model_t1,
    hybrid_model_t4,
    hybrid_model_t5,
    hybrid_model_u8,
    hybrid_model_u15,
    hybrid_model_u25,
    hybrid_model_u35,
    hybrid_model_u42,
]
In [ ]:
inference_by_mult_datasets_models( inf_published = True, oth_recons = oth_recons, dataset_list = [ dataset], test_data_list = [test_data_10] ) # have to config fbp ray trafo for all new models if we want to run inference for ellipses
running task 0/18 ...
running task 1/18 ...
running task 2/18 ...
running task 3/18 ...
running task 4/18 ...
running task 5/18 ...
running task 6/18 ...
running task 7/18 ...
running task 8/18 ...
running task 9/18 ...
running task 10/18 ...
running task 11/18 ...
running task 12/18 ...
running task 13/18 ...
running task 14/18 ...
running task 15/18 ...
running task 16/18 ...
running task 17/18 ...
ResultTable(results=
                                                 reconstructor       test_data                         measure_values
task_ind sub_task_ind                                                                                                
0        0                                    FBPReconstructor  test part 0:10       mean: {psnr: 25.4, ssim: 0.4552}
1        0                            GaussNewtonReconstructor  test part 0:10      mean: {psnr: 18.47, ssim: 0.4831}
2        0                                   ISTAReconstructor  test part 0:10      mean: {psnr: 11.6, ssim: 0.07372}
3        0                                     CGReconstructor  test part 0:10      mean: {psnr: 26.43, ssim: 0.6441}
4        0                              LandweberReconstructor  test part 0:10      mean: {psnr: 23.01, ssim: 0.5603}
5        0                                   MLEMReconstructor  test part 0:10      mean: {psnr: 17.93, ssim: 0.4488}
6        0                                   ADMMReconstructor  test part 0:10      mean: {psnr: 11.6, ssim: 0.07373}
7        0                                   BFGSReconstructor  test part 0:10  mean: {psnr: -45.97, ssim: 5.417e-08}
8        0                              IRadonMapReconstructor  test part 0:10       mean: {psnr: 30.4, ssim: 0.7293}
9        0             Hbrd Transf0-CT Progr Training 300K Itr  test part 0:10      mean: {psnr: 25.89, ssim: 0.4755}
10       0             Hbrd Transf1-CT retrn fixPatch 240K Itr  test part 0:10      mean: {psnr: 30.58, ssim: 0.7132}
11       0              Hbrd Transf4-CT Progr Training 24K Itr  test part 0:10      mean: {psnr: 27.24, ssim: 0.5433}
12       0                      Hbrd Transf5-Pre_Trained_blind  test part 0:10      mean: {psnr: 25.89, ssim: 0.4772}
13       0                  Hybrid UNet Residual model sigma 8  test part 0:10        mean: {psnr: 22.99, ssim: 0.48}
14       0                 Hybrid UNet Residual model sigma 15  test part 0:10       mean: {psnr: 27.63, ssim: 0.618}
15       0                 Hybrid UNet Residual model sigma 25  test part 0:10      mean: {psnr: 29.35, ssim: 0.6853}
16       0                 Hybrid UNet Residual model sigma 35  test part 0:10       mean: {psnr: 30.14, ssim: 0.722}
17       0                 Hybrid UNet Residual model sigma 35  test part 0:10      mean: {psnr: 29.52, ssim: 0.7064}
)
In [ ]:
inference_by_mult_datasets_models( inf_published = True, oth_recons = oth_recons, dataset_list = [ dataset], test_data_list = [test_data] )
running task 0/18 ...
running task 1/18 ...
running task 2/18 ...
running task 3/18 ...
running task 4/18 ...
running task 5/18 ...
running task 6/18 ...
running task 7/18 ...
running task 8/18 ...
running task 9/18 ...
running task 10/18 ...
running task 11/18 ...
running task 12/18 ...
running task 13/18 ...
running task 14/18 ...
running task 15/18 ...

end¶

In [ ]:
# Terminal cell
# Shell escape: report host (WSL2) memory usage so long evaluation runs can be
# checked against the RAM actually available to the kernel.

!free -h # check WSL2 RAM